Coverage for ibllib/io/extractors/ephys_fpga.py: 92%
565 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 14:26 +0100
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 14:26 +0100
1"""Data extraction from raw FPGA output.
3The behaviour extraction happens in the following stages:
5 1. The NI DAQ events are extracted into a map of event times and TTL polarities.
6 2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol.
7 3. As protocols may be chained together within a given recording, the period of a given task
8 protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`).
9 4. Physical behaviour events such as stim on and reward time are separated out by TTL length or
10 sequence within the trial.
11 5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events.
12 6. The Bpod software events are then converted to FPGA time.
14Examples
15--------
16For simple extraction, use the FPGATrials class:
18>>> extractor = FpgaTrials(session_path)
19>>> trials, _ = extractor.extract(update=False, save=False)
21Notes
22-----
23Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a
24Neuropixels recording system, however a sync and channel map extracted from a different DAQ format
25can be passed to the FpgaTrials class.
27See Also
28--------
29For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class.
31TODO notes on subclassing various methods of FpgaTrials for custom hardware.
32"""
33import logging
34from itertools import cycle
35from pathlib import Path
36import uuid
37import re
39import matplotlib.pyplot as plt
40from matplotlib.colors import TABLEAU_COLORS
41import numpy as np
42from packaging import version
44import spikeglx
45import ibldsp.utils
46import one.alf.io as alfio
47from one.alf.path import filename_parts
48from iblutil.util import Bunch
49from iblutil.spacer import Spacer
51import ibllib.exceptions as err
52from ibllib.io import raw_data_loaders as raw, session_params
53from ibllib.io.extractors.bpod_trials import get_bpod_extractor
54import ibllib.io.extractors.base as extractors_base
55from ibllib.io.extractors.training_wheel import extract_wheel_moves
56from ibllib import plots
57from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS
59_logger = logging.getLogger(__name__)
61SYNC_BATCH_SIZE_SECS = 100
62"""int: Number of samples to read at once in bin file for sync."""
64WHEEL_RADIUS_CM = 1 # stay in radians
65"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians."""
67WHEEL_TICKS = 1024
68"""int: The number of encoder pulses per channel for one complete rotation."""
70BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150
71"""int: Logs a warning if Bpod to FPGA clock drift is higher than this value."""
73CHMAPS = {'3A':
74 {'ap':
75 {'left_camera': 2,
76 'right_camera': 3,
77 'body_camera': 4,
78 'bpod': 7,
79 'frame2ttl': 12,
80 'rotary_encoder_0': 13,
81 'rotary_encoder_1': 14,
82 'audio': 15
83 }
84 },
85 '3B':
86 {'nidq':
87 {'left_camera': 0,
88 'right_camera': 1,
89 'body_camera': 2,
90 'imec_sync': 3,
91 'frame2ttl': 4,
92 'rotary_encoder_0': 5,
93 'rotary_encoder_1': 6,
94 'audio': 7,
95 'bpod': 16,
96 'laser': 17,
97 'laser_ttl': 18},
98 'ap':
99 {'imec_sync': 6}
100 },
101 }
102"""dict: The default channel indices corresponding to various devices for different recording systems."""
105def data_for_keys(keys, data):
106 """Check keys exist in 'data' dict and contain values other than None."""
107 return data is not None and all(k in data and data.get(k, None) is not None for k in keys) 1fNROPDEQuZSebmnold
110def get_ibl_sync_map(ef, version):
111 """
112 Gets default channel map for the version/binary file type combination
113 :param ef: ibllib.io.spikeglx.glob_ephys_file dictionary with field 'ap' or 'nidq'
114 :return: channel map dictionary
115 """
116 # Determine default channel map
117 if version == '3A': 10JKFGfqrsNOPDEQuHIZSaecgbmnolkdh
118 default_chmap = CHMAPS['3A']['ap'] 10Zaecmnod
119 elif version == '3B': 10JKFGfqrsNOPDEQuHISagblkdh
120 if ef.get('nidq', None): 10JKFGfqrsNOPDEQuHISagblkdh
121 default_chmap = CHMAPS['3B']['nidq'] 10JKFGfqrsNOPDEQuHISagblkdh
122 elif ef.get('ap', None): 10qrsSlk
123 default_chmap = CHMAPS['3B']['ap'] 10qrsSlk
124 # Try to load channel map from file
125 chmap = spikeglx.get_sync_map(ef['path']) 10JKFGfqrsNOPDEQuHIZSaecgbmnolkdh
126 # If chmap provided but not with all keys, fill up with default values
127 if not chmap: 10JKFGfqrsNOPDEQuHIZSaecgbmnolkdh
128 return default_chmap 10JKFGqrsHIZSacglkh
129 else:
130 if data_for_keys(default_chmap.keys(), chmap): 1fNOPDEQuZSebmnold
131 return chmap 1fuZSebmnold
132 else:
133 _logger.warning("Keys missing from provided channel map, " 1NOPDEQZmnd
134 "setting missing keys from default channel map")
135 return {**default_chmap, **chmap} 1NOPDEQZmnd
138def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''):
139 """
140 Extracts sync.times, sync.channels and sync.polarities from binary ephys dataset
142 :param raw_ephys_apfile: bin file containing ephys data or spike
143 :param output_path: output directory
144 :param save: bool write to disk only if True
145 :param parts: string or list of strings that will be appended to the filename before extension
146 :return:
147 """
148 # handles input argument: support ibllib.io.spikeglx.Reader, str and pathlib.Path
149 if isinstance(raw_ephys_apfile, spikeglx.Reader): 1xyzqrsABCk
150 sr = raw_ephys_apfile 1xyzqrsABCk
151 else:
152 raw_ephys_apfile = Path(raw_ephys_apfile)
153 sr = spikeglx.Reader(raw_ephys_apfile)
154 if not (opened := sr.is_open): 1xyzqrsABCk
155 sr.open()
156 # if no output, need a temp folder to swap for big files
157 if not output_path: 1xyzqrsABCk
158 output_path = raw_ephys_apfile.parent
159 file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{uuid.uuid4()}.bin') 1xyzqrsABCk
161 # loop over chunks of the raw ephys file
162 wg = ibldsp.utils.WindowGenerator(sr.ns, int(SYNC_BATCH_SIZE_SECS * sr.fs), overlap=1) 1xyzqrsABCk
163 fid_ftcp = open(file_ftcp, 'wb') 1xyzqrsABCk
164 for sl in wg.slice: 1xyzqrsABCk
165 ss = sr.read_sync(sl) 1xyzqrsABCk
166 ind, fronts = ibldsp.utils.fronts(ss, axis=0) 1xyzqrsABCk
167 # a = sr.read_sync_analog(sl)
168 sav = np.c_[(ind[0, :] + sl.start) / sr.fs, ind[1, :], fronts.astype(np.double)] 1xyzqrsABCk
169 sav.tofile(fid_ftcp) 1xyzqrsABCk
170 # close temp file, read from it and delete
171 fid_ftcp.close() 1xyzqrsABCk
172 tim_chan_pol = np.fromfile(str(file_ftcp)) 1xyzqrsABCk
173 tim_chan_pol = tim_chan_pol.reshape((int(tim_chan_pol.size / 3), 3)) 1xyzqrsABCk
174 file_ftcp.unlink() 1xyzqrsABCk
175 sync = {'times': tim_chan_pol[:, 0], 1xyzqrsABCk
176 'channels': tim_chan_pol[:, 1],
177 'polarities': tim_chan_pol[:, 2]}
178 # If opened Reader was passed into function, leave open
179 if not opened: 1xyzqrsABCk
180 sr.close()
181 if save: 1xyzqrsABCk
182 out_files = alfio.save_object_npy(output_path, sync, 'sync', 1xyzqrsABCk
183 namespace='spikeglx', parts=parts)
184 return Bunch(sync), out_files 1xyzqrsABCk
185 else:
186 return Bunch(sync)
189def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4'):
190 """
191 Extracts the rotary encoder absolute position as function of time from fronts detected
192 on the 2 channels. Outputs in units of radius parameters, by default radians
193 Coding options detailed here: http://www.ni.com/tutorial/7109/pt/
194 Here output is clockwise from subject perspective
196 :param ta: time of fronts on channel A
197 :param pa: polarity of fronts on channel A
198 :param tb: time of fronts on channel B
199 :param pb: polarity of fronts on channel B
200 :param ticks: number of ticks corresponding to a full revolution (1024 for IBL rotary encoder)
201 :param radius: radius of the wheel. Defaults to 1 for an output in radians
202 :param coding: x1, x2 or x4 coding (IBL default is x4)
203 :return: indices vector (ta) and position vector
204 """
205 if coding == 'x1': 17W81Lfaecgbmnolkdh
206 ia = np.searchsorted(tb, ta[pa == 1]) 1WL
207 ia = ia[ia < ta.size] 1WL
208 ia = ia[pa[ia] == 1] 1WL
209 ib = np.searchsorted(ta, tb[pb == 1]) 1WL
210 ib = ib[ib < tb.size] 1WL
211 ib = ib[pb[ib] == 1] 1WL
212 t = np.r_[ta[ia], tb[ib]] 1WL
213 p = np.r_[ia * 0 + 1, ib * 0 - 1] 1WL
214 ordre = np.argsort(t) 1WL
215 t = t[ordre] 1WL
216 p = p[ordre] 1WL
217 p = np.cumsum(p) / ticks * np.pi * 2 * radius 1WL
218 return t, p 1WL
219 elif coding == 'x2': 1781Lfaecgbmnolkdh
220 p = pb[np.searchsorted(tb, ta) - 1] * pa 178L
221 p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 2 178L
222 return ta, p 178L
223 elif coding == 'x4': 11Lfaecgbmnolkdh
224 p = np.r_[pb[np.searchsorted(tb, ta) - 1] * pa, -pa[np.searchsorted(ta, tb) - 1] * pb] 11Lfaecgbmnolkdh
225 t = np.r_[ta, tb] 11Lfaecgbmnolkdh
226 ordre = np.argsort(t) 11Lfaecgbmnolkdh
227 t = t[ordre] 11Lfaecgbmnolkdh
228 p = p[ordre] 11Lfaecgbmnolkdh
229 p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 4 11Lfaecgbmnolkdh
230 return t, p 11Lfaecgbmnolkdh
233def _assign_events_to_trial(t_trial_start, t_event, take='last', t_trial_end=None):
234 """
235 Assign events to a trial given trial start times and event times.
237 Trials without an event result in nan value in output time vector.
238 The output has a consistent size with t_trial_start and ready to output to alf.
240 Parameters
241 ----------
242 t_trial_start : numpy.array
243 An array of start times, used to bin edges for assigning values from `t_event`.
244 t_event : numpy.array
245 An array of event times to assign to trials.
246 take : str {'first', 'last'}, int
247 'first' takes first event > t_trial_start; 'last' takes last event < the next
248 t_trial_start; an int defines the index to take for events within trial bounds. The index
249 may be negative.
250 t_trial_end : numpy.array
251 Optional array of end times, used to bin edges for assigning values from `t_event`.
253 Returns
254 -------
255 numpy.array
256 An array the length of `t_trial_start` containing values from `t_event`. Unassigned values
257 are replaced with np.nan.
259 See Also
260 --------
261 FpgaTrials._assign_events - Assign trial events based on TTL length.
262 """
263 # make sure the events are sorted
264 try: 1Mfjiaecgbdh
265 assert np.all(np.diff(t_trial_start) >= 0) 1Mfjiaecgbdh
266 except AssertionError: 1M
267 raise ValueError('Trial starts vector not sorted') 1M
268 try: 1Mfjiaecgbdh
269 assert np.all(np.diff(t_event) >= 0) 1Mfjiaecgbdh
270 except AssertionError: 1M
271 raise ValueError('Events vector is not sorted') 1M
273 # remove events that happened before the first trial start
274 remove = t_event < t_trial_start[0] 1Mfjiaecgbdh
275 if t_trial_end is not None: 1Mfjiaecgbdh
276 if not np.all(np.diff(t_trial_end) >= 0): 1fjiaecgbdh
277 raise ValueError('Trial end vector not sorted')
278 if not np.all(t_trial_end[:-1] < t_trial_start[1:]): 1fjiaecgbdh
279 raise ValueError('Trial end times must not overlap with trial start times')
280 # remove events between end and next start, and after last end
281 remove |= t_event > t_trial_end[-1] 1fjiaecgbdh
282 for e, s in zip(t_trial_end[:-1], t_trial_start[1:]): 1fjiaecgbdh
283 remove |= np.logical_and(s > t_event, t_event >= e) 1fjiaecgbdh
284 t_event = t_event[~remove] 1Mfjiaecgbdh
285 ind = np.searchsorted(t_trial_start, t_event) - 1 1Mfjiaecgbdh
286 t_event_nans = np.zeros_like(t_trial_start) * np.nan 1Mfjiaecgbdh
287 # select first or last element matching each trial start
288 if take == 'last': 1Mfjiaecgbdh
289 iall, iu = np.unique(np.flip(ind), return_index=True) 1Mfjiaecgbdh
290 t_event_nans[iall] = t_event[- (iu - ind.size + 1)] 1Mfjiaecgbdh
291 elif take == 'first': 1Mfjiaecgbdh
292 iall, iu = np.unique(ind, return_index=True) 1Mfjiaecgbdh
293 t_event_nans[iall] = t_event[iu] 1Mfjiaecgbdh
294 else: # if the index is arbitrary, needs to be numeric (could be negative if from the end)
295 iall = np.unique(ind) 1M
296 minsize = take + 1 if take >= 0 else - take 1M
297 # for each trial, take the take nth element if there are enough values in trial
298 for iu in iall: 1M
299 match = t_event[iu == ind] 1M
300 if len(match) >= minsize: 1M
301 t_event_nans[iu] = match[take] 1M
302 return t_event_nans 1Mfjiaecgbdh
305def get_sync_fronts(sync, channel_nb, tmin=None, tmax=None):
306 """
307 Return the sync front polarities and times for a given channel.
309 Parameters
310 ----------
311 sync : dict
312 'polarities' of fronts detected on sync trace for all 16 channels and their 'times'.
313 channel_nb : int
314 The integer corresponding to the desired sync channel.
315 tmin : float
316 The minimum time from which to extract the sync pulses.
317 tmax : float
318 The maximum time up to which we extract the sync pulses.
320 Returns
321 -------
322 Bunch
323 Channel times and polarities.
324 """
325 selection = sync['channels'] == channel_nb 1wvtJKFGfqrsjiNRTUVOPDEQuHISaecgbmnolkdh
326 selection = np.logical_and(selection, sync['times'] <= tmax) if tmax else selection 1wvtJKFGfqrsjiNRTUVOPDEQuHISaecgbmnolkdh
327 selection = np.logical_and(selection, sync['times'] >= tmin) if tmin else selection 1wvtJKFGfqrsjiNRTUVOPDEQuHISaecgbmnolkdh
328 return Bunch({'times': sync['times'][selection], 1wvtJKFGfqrsjiNRTUVOPDEQuHISaecgbmnolkdh
329 'polarities': sync['polarities'][selection]})
332def _clean_audio(audio, display=False):
333 """
334 one guy wired the 150 Hz camera output onto the soundcard. The effect is to get 150 Hz periodic
335 square pulses, 2ms up and 4.666 ms down. When this happens we remove all of the intermediate
336 pulses to repair the audio trace
337 Here is some helper code
338 dd = np.diff(audio['times'])
339 1 / np.median(dd[::2]) # 2ms up
340 1 / np.median(dd[1::2]) # 4.666 ms down
341 1 / (np.median(dd[::2]) + np.median(dd[1::2])) # both sum to 150 Hz
342 This only runs on sessions when the bug is detected and leaves others untouched
343 """
344 DISCARD_THRESHOLD = 0.01 1wYfjiaecgbdh
345 average_150_hz = np.mean(1 / np.diff(audio['times'][audio['polarities'] == 1]) > 140) 1wYfjiaecgbdh
346 naudio = audio['times'].size 1wYfjiaecgbdh
347 if average_150_hz > 0.7 and naudio > 100: 1wYfjiaecgbdh
348 _logger.warning('Soundcard signal on FPGA seems to have been mixed with 150Hz camera') 1Y
349 keep_ind = np.r_[np.diff(audio['times']) > DISCARD_THRESHOLD, False] 1Y
350 keep_ind = np.logical_and(keep_ind, audio['polarities'] == -1) 1Y
351 keep_ind = np.where(keep_ind)[0] 1Y
352 keep_ind = np.sort(np.r_[0, keep_ind, keep_ind + 1, naudio - 1]) 1Y
354 if display: # pragma: no cover 1Y
355 from ibllib.plots import squares
356 squares(audio['times'], audio['polarities'], ax=None, yrange=[-1, 1])
357 squares(audio['times'][keep_ind], audio['polarities'][keep_ind], yrange=[-1, 1])
358 audio = {'times': audio['times'][keep_ind], 1Y
359 'polarities': audio['polarities'][keep_ind]}
360 return audio 1wYfjiaecgbdh
363def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False):
364 """
365 Clean the frame2ttl events.
367 Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic
368 pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold
369 of F2TTL_THRESH.
371 Parameters
372 ----------
373 frame2ttl : dict
374 A dictionary of frame2TTL events, with keys {'times', 'polarities'}.
375 threshold : float
376 Consecutive pulses occurring with this many seconds ignored.
377 display : bool
378 If true, plots the input TTLs and the cleaned output.
380 Returns
381 -------
383 """
384 dt = np.diff(frame2ttl['times']) 1p3FGf4jiuHIaecgb56dh
385 iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0] 1p3FGf4jiuHIaecgb56dh
386 iko = np.unique(np.r_[iko, iko + 1]) 1p3FGf4jiuHIaecgb56dh
387 frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko), 1p3FGf4jiuHIaecgb56dh
388 'polarities': np.delete(frame2ttl['polarities'], iko)}
389 if iko.size > (0.1 * frame2ttl['times'].size): 1p3FGf4jiuHIaecgb56dh
390 _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) ' 13je
391 f'frame to TTL polarity switches below {threshold} secs')
392 if display: # pragma: no cover 1p3FGf4jiuHIaecgb56dh
393 fig, (ax0, ax1) = plt.subplots(2, sharex=True)
394 plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0)
395 plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1)
396 import seaborn as sns
397 sns.displot(dt[dt < 0.05], binwidth=0.0005)
399 return frame2ttl_ 1p3FGf4jiuHIaecgb56dh
402def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None):
403 """
404 Extract wheel positions and times from sync fronts dictionary for all 16 channels.
405 Output position is in radians, mathematical convention.
407 Parameters
408 ----------
409 sync : dict
410 'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
411 chmap : dict
412 Map of channel names and their corresponding index. Default to constant.
413 tmin : float
414 The minimum time from which to extract the sync pulses.
415 tmax : float
416 The maximum time up to which we extract the sync pulses.
418 Returns
419 -------
420 numpy.array
421 Wheel timestamps in seconds.
422 numpy.array
423 Wheel positions in radians.
424 """
425 # Assume two separate edge count channels
426 assert chmap.keys() >= {'rotary_encoder_0', 'rotary_encoder_1'} 1faecgbmnolkdh
427 channela = get_sync_fronts(sync, chmap['rotary_encoder_0'], tmin=tmin, tmax=tmax) 1faecgbmnolkdh
428 channelb = get_sync_fronts(sync, chmap['rotary_encoder_1'], tmin=tmin, tmax=tmax) 1faecgbmnolkdh
429 re_ts, re_pos = _rotary_encoder_positions_from_fronts( 1faecgbmnolkdh
430 channela['times'], channela['polarities'], channelb['times'], channelb['polarities'],
431 ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4')
432 return re_ts, re_pos 1faecgbmnolkdh
435def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'):
436 """
437 Reads ephys binary file (s) and extract sync within the binary file folder
438 Assumes ephys data is within a `raw_ephys_data` folder
440 :param session_path: '/path/to/subject/yyyy-mm-dd/001'
441 :param overwrite: Bool on re-extraction, forces overwrite instead of loading existing files
442 :return: list of sync dictionaries
443 """
444 session_path = Path(session_path) 1xyzqrsABCXmnolk
445 if not ephys_files: 1xyzqrsABCXmnolk
446 ephys_files = spikeglx.glob_ephys_files(session_path) 1xyzXmnolk
447 syncs = [] 1xyzqrsABCXmnolk
448 outputs = [] 1xyzqrsABCXmnolk
449 for efi in ephys_files: 1xyzqrsABCXmnolk
450 bin_file = efi.get('ap', efi.get('nidq', None)) 1xyzqrsABCmnolk
451 if not bin_file: 1xyzqrsABCmnolk
452 continue
453 alfname = dict(object='sync', namespace=namespace) 1xyzqrsABCmnolk
454 if efi.label: 1xyzqrsABCmnolk
455 alfname['extra'] = efi.label 1qrsmnolk
456 file_exists = alfio.exists(bin_file.parent, **alfname) 1xyzqrsABCmnolk
457 if not overwrite and file_exists: 1xyzqrsABCmnolk
458 _logger.warning(f'Skipping raw sync: SGLX sync found for {efi.label}!') 1xyzmnol
459 sync = alfio.load_object(bin_file.parent, **alfname) 1xyzmnol
460 out_files, _ = alfio._ls(bin_file.parent, **alfname) 1xyzmnol
461 else:
462 sr = spikeglx.Reader(bin_file) 1xyzqrsABCk
463 sync, out_files = _sync_to_alf(sr, bin_file.parent, save=True, parts=efi.label) 1xyzqrsABCk
464 sr.close() 1xyzqrsABCk
465 outputs.extend(out_files) 1xyzqrsABCmnolk
466 syncs.extend([sync]) 1xyzqrsABCmnolk
468 return syncs, outputs 1xyzqrsABCXmnolk
471def _get_all_probes_sync(session_path, bin_exists=True):
472 # round-up of all bin ephys files in the session, infer revision and get sync map
473 ephys_files = spikeglx.glob_ephys_files(session_path, bin_exists=bin_exists) 1JKDEaecXmnolkd
474 version = spikeglx.get_neuropixel_version_from_files(ephys_files) 1JKDEaecXmnolkd
475 # attach the sync information to each binary file found
476 for ef in ephys_files: 1JKDEaecXmnolkd
477 ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True) 1JKDEaecmnolkd
478 ef['sync_map'] = get_ibl_sync_map(ef, version) 1JKDEaecmnolkd
479 return ephys_files 1JKDEaecXmnolkd
482def get_wheel_positions(sync, chmap, tmin=None, tmax=None):
483 """
484 Gets the wheel position from synchronisation pulses
486 Parameters
487 ----------
488 sync : dict
489 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
490 the corresponding channel numbers.
491 chmap : dict[str, int]
492 A map of channel names and their corresponding indices.
493 tmin : float
494 The minimum time from which to extract the sync pulses.
495 tmax : float
496 The maximum time up to which we extract the sync pulses.
498 Returns
499 -------
500 Bunch
501 A dictionary with keys ('timestamps', 'position'), containing the wheel event timestamps and
502 position in radians
503 Bunch
504 A dictionary of detected movement times with keys ('intervals', 'peakAmplitude', 'peakVelocity_times').
505 """
506 ts, pos = extract_wheel_sync(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) 1faecgbdh
507 moves = Bunch(extract_wheel_moves(ts, pos)) 1faecgbdh
508 wheel = Bunch({'timestamps': ts, 'position': pos}) 1faecgbdh
509 return wheel, moves 1faecgbdh
512def get_main_probe_sync(session_path, bin_exists=False):
513 """
514 From 3A or 3B multiprobe session, returns the main probe (3A) or nidq sync pulses
515 with the attached channel map (default chmap if none)
517 Parameters
518 ----------
519 session_path : str, pathlib.Path
520 The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
521 bin_exists : bool
522 Whether there is a .bin file present.
524 Returns
525 -------
526 one.alf.io.AlfBunch
527 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
528 the corresponding channel numbers.
529 dict
530 A map of channel names and their corresponding indices.
531 """
532 ephys_files = _get_all_probes_sync(session_path, bin_exists=bin_exists) 1JKDEaecXmnolkd
533 if not ephys_files: 1JKDEaecXmnolkd
534 raise FileNotFoundError(f"No ephys files found in {session_path}") 1X
535 version = spikeglx.get_neuropixel_version_from_files(ephys_files) 1JKDEaecmnolkd
536 if version == '3A': 1JKDEaecmnolkd
537 # the sync master is the probe with the most sync pulses
538 sync_box_ind = np.argmax([ef.sync.times.size for ef in ephys_files]) 1aecmnod
539 elif version == '3B': 1JKDElk
540 # the sync master is the nidq breakout box
541 sync_box_ind = np.argmax([1 if ef.get('nidq') else 0 for ef in ephys_files]) 1JKDElk
542 sync = ephys_files[sync_box_ind].sync 1JKDEaecmnolkd
543 sync_chmap = ephys_files[sync_box_ind].sync_map 1JKDEaecmnolkd
544 return sync, sync_chmap 1JKDEaecmnolkd
547def get_protocol_period(session_path, protocol_number, bpod_sync, exclude_empty_periods=True):
548 """
549 Return the start and end time of the protocol number.
551 Note that the start time is the start of the spacer pulses and the end time is either None
552 if the protocol is the final one, or the start of the next spacer.
554 Parameters
555 ----------
556 session_path : str, pathlib.Path
557 The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
558 protocol_number : int
559 The order that the protocol was run in, counted from 0.
560 bpod_sync : dict
561 The sync times and polarities for Bpod BNC1.
562 exclude_empty_periods : bool
563 When true, spacers are ignored if no bpod pulses are detected between periods.
565 Returns
566 -------
567 float
568 The time of the detected spacer for the protocol number.
569 float, None
570 The time of the next detected spacer or None if this is the last protocol run.
571 """
572 # The spacers are TTLs generated by Bpod at the start of each protocol
573 sp = Spacer() 1iu
574 spacer_times = sp.find_spacers_from_fronts(bpod_sync) 1iu
575 if exclude_empty_periods: 1iu
576 # Drop dud protocol spacers (those without any bpod pulses after the spacer)
577 spacer_length = len(sp.generate_template(fs=1000)) / 1000 1iu
578 periods = np.c_[spacer_times + spacer_length, np.r_[spacer_times[1:], np.inf]] 1iu
579 valid = [np.any((bpod_sync['times'] > pp[0]) & (bpod_sync['times'] < pp[1])) for pp in periods] 1iu
580 spacer_times = spacer_times[valid] 1iu
581 # Ensure that the number of detected spacers matched the number of expected tasks
582 if acquisition_description := session_params.read_params(session_path): 1iu
583 n_tasks = len(acquisition_description.get('tasks', [])) 1iu
584 assert len(spacer_times) >= protocol_number, (f'expected {n_tasks} spacers, found only {len(spacer_times)} - ' 1iu
585 f'can not return protocol number {protocol_number}.')
586 assert n_tasks > protocol_number >= 0, f'protocol number must be between 0 and {n_tasks}' 1iu
587 else:
588 assert protocol_number < len(spacer_times)
589 start = spacer_times[int(protocol_number)] 1iu
590 end = None if len(spacer_times) - 1 == protocol_number else spacer_times[int(protocol_number + 1)] 1iu
591 return start, end 1iu
594class FpgaTrials(extractors_base.BaseExtractor):
595 save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy',
596 '_ibl_trials.stimOffTrigger_times.npy', None, None, None, None, None,
597 '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy',
598 '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy',
599 '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy',
600 '_ibl_wheelMoves.peakAmplitude.npy', None)
601 var_names = ('goCueTrigger_times', 'stimOnTrigger_times',
602 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times',
603 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times',
604 'valveOpen_times', 'phase', 'position', 'quiescence', 'table',
605 'wheel_timestamps', 'wheel_position',
606 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times')
608 bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times',
609 'stimOnTrigger_times', 'stimOffTrigger_times',
610 'stimFreezeTrigger_times', 'errorCueTrigger_times')
611 """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
613 bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight',
614 'probabilityLeft', 'phase', 'position', 'quiescence')
615 """tuple of str: Fields from bpod extractor that we want to save."""
617 sync_field = 'intervals_0' # trial start events
618 """str: The trial event to synchronize (must be present in extracted trials)."""
620 bpod = None
621 """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot."""
623 def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs):
624 """An extractor for ephysChoiceWorld trials data, in FPGA time.
626 This class may be subclassed to handle moderate variations in hardware and task protocol,
627 however there is flexible
628 """
629 super().__init__(*args, **kwargs) 12wvtf9!jiaecgbmnolkdh
630 self.bpod2fpga = None 12wvtf9!jiaecgbmnolkdh
631 self.bpod_trials = bpod_trials 12wvtf9!jiaecgbmnolkdh
632 self.frame2ttl = self.audio = self.bpod = self.settings = None 12wvtf9!jiaecgbmnolkdh
633 if bpod_extractor: 12wvtf9!jiaecgbmnolkdh
634 self.bpod_extractor = bpod_extractor 1fjiaecgbdh
635 self._update_var_names() 1fjiaecgbdh
637 def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None):
638 """
639 Updates this object's attributes based on the Bpod trials extractor.
641 Fields updated: bpod_fields, bpod_rsync_fields, save_names, and var_names.
643 Parameters
644 ----------
645 bpod_fields : tuple
646 A set of Bpod trials fields to keep.
647 bpod_rsync_fields : tuple
648 A set of Bpod trials fields to sync to the DAQ times.
649 """
650 if self.bpod_extractor: 1fjiaecgbdh
651 for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names): 1fjiaecgbdh
652 if var_name not in self.var_names: 1fjiaecgbdh
653 self.var_names += (var_name,) 1fjaecgbdh
654 self.save_names += (save_name,) 1fjaecgbdh
656 # self.var_names = self.bpod_extractor.var_names
657 # self.save_names = self.bpod_extractor.save_names
658 self.settings = self.bpod_extractor.settings # This is used by the TaskQC 1fjiaecgbdh
659 self.bpod_rsync_fields = bpod_rsync_fields 1fjiaecgbdh
660 if self.bpod_rsync_fields is None: 1fjiaecgbdh
661 self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names)) 1fjiaecgbdh
662 if 'table' in self.bpod_extractor.var_names: 1fjiaecgbdh
663 if not self.bpod_trials: 1fjaecgbdh
664 self.bpod_trials = self.bpod_extractor.extract(save=False)
665 table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() 1fjaecgbdh
666 self.bpod_rsync_fields += tuple(self._time_fields(table_keys)) 1fjaecgbdh
667 elif bpod_rsync_fields:
668 self.bpod_rsync_fields = bpod_rsync_fields
669 excluded = (*self.bpod_rsync_fields, 'table') 1fjiaecgbdh
670 if bpod_fields: 1fjiaecgbdh
671 assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields'
672 self.bpod_fields = bpod_fields
673 elif self.bpod_extractor: 1fjiaecgbdh
674 self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded) 1fjiaecgbdh
675 if 'table' in self.bpod_extractor.var_names: 1fjiaecgbdh
676 if not self.bpod_trials: 1fjaecgbdh
677 self.bpod_trials = self.bpod_extractor.extract(save=False)
678 table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() 1fjaecgbdh
679 self.bpod_fields += tuple([x for x in table_keys if x not in excluded]) 1fjaecgbdh
681 @staticmethod
682 def _time_fields(trials_attr) -> set:
683 """
684 Iterates over Bpod trials attributes returning those that correspond to times for syncing.
686 Parameters
687 ----------
688 trials_attr : iterable of str
689 The Bpod field names.
691 Returns
692 -------
693 set
694 The field names that contain timestamps.
695 """
696 FIELDS = ('times', 'timestamps', 'intervals') 1#fjiaecgbdh
697 pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') 1#fjiaecgbdh
698 return set(filter(pattern.match, trials_attr)) 1#fjiaecgbdh
700 def load_sync(self, sync_collection='raw_ephys_data', **kwargs):
701 """Load the DAQ sync and channel map data.
703 This method may be subclassed for novel DAQ systems. The sync must contain the following
704 keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1}
705 corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints
706 corresponding to channel number.
708 Parameters
709 ----------
710 sync_collection : str
711 The session subdirectory where the sync data are located.
712 kwargs
713 Optional arguments used by subclass methods.
715 Returns
716 -------
717 one.alf.io.AlfBunch
718 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
719 and the corresponding channel numbers.
720 dict
721 A map of channel names and their corresponding indices.
722 """
723 return get_sync_and_chn_map(self.session_path, sync_collection) 1c
725 def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
726 task_collection='raw_behavior_data', **kwargs) -> dict:
727 """Extracts ephys trials by combining Bpod and FPGA sync pulses.
729 It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
730 attributes are all correct for the bpod protocol used.
732 Below are the steps involved:
733 0. Load sync and bpod trials, if required.
734 1. Determine protocol period and discard sync events outside the task.
735 2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`).
736 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events.
737 4. Assign classified TTL events to trial events based on order within the trial.
738 4. Convert Bpod software event times to DAQ clock.
739 5. Extract the wheel from the DAQ rotary encoder signal, if required.
741 Parameters
742 ----------
743 sync : dict
744 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
745 and the corresponding channel numbers. If None, the sync is loaded using the
746 `load_sync` method.
747 chmap : dict
748 A map of channel names and their corresponding indices. If None, the channel map is
749 loaded using the :meth:`FpgaTrials.load_sync` method.
750 sync_collection : str
751 The session subdirectory where the sync data are located. This is only used if the
752 sync or channel maps are not provided.
753 task_collection : str
754 The session subdirectory where the raw Bpod data are located. This is used for loading
755 the task settings and extracting the bpod trials, if not already done.
756 protocol_number : int
757 The protocol number if multiple protocols were run during the session. If provided, a
758 spacer signal must be present in order to determine the correct period.
759 kwargs
760 Optional arguments for subclass methods to use.
762 Returns
763 -------
764 dict
765 A dictionary of numpy arrays with `FpgaTrials.var_names` as keys.
766 """
767 if sync is None or chmap is None: 1fjiaecgbdh
768 _sync, _chmap = self.load_sync(sync_collection) 1ji
769 sync = sync or _sync 1ji
770 chmap = chmap or _chmap 1ji
772 if not self.bpod_trials: # extract the behaviour data from bpod 1fjiaecgbdh
773 self.extractor = get_bpod_extractor(self.session_path, task_collection=task_collection)
774 _logger.info('Bpod trials extractor: %s.%s', self.extractor.__module__, self.extractor.__class__.__name__)
775 self.bpod_trials, *_ = self.extractor.extract(task_collection=task_collection, save=False, **kwargs)
777 # Explode trials table df
778 if 'table' in self.var_names: 1fjiaecgbdh
779 trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) 1fjaecgbdh
780 table_columns = trials_table.keys() 1fjaecgbdh
781 self.bpod_trials.update(trials_table) 1fjaecgbdh
782 else:
783 if 'table' in self.bpod_trials: 1ia
784 _logger.error(
785 '"table" found in Bpod trials but missing from `var_names` attribute and will'
786 'therefore not be extracted. This is likely in error.')
787 table_columns = None 1ia
789 bpod = get_sync_fronts(sync, chmap['bpod']) 1fjiaecgbdh
790 # Get the spacer times for this protocol
791 if any(arg in kwargs for arg in ('tmin', 'tmax')): 1fjiaecgbdh
792 tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') 1cdh
793 elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer 1fjiaegb
794 # The spacers are TTLs generated by Bpod at the start of each protocol
795 tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) 1i
796 tmin += (Spacer().times[-1] + Spacer().tup + 0.05) # exclude spacer itself 1i
797 else:
798 # Older sessions don't have protocol spacers so we sync the Bpod intervals here to
799 # find the approximate end time of the protocol (this will exclude the passive signals
800 # in ephysChoiceWorld that tend to ruin the final trial extraction).
801 _, trial_ints = self.get_bpod_event_times(sync, chmap, **kwargs) 1fjaegb
802 t_trial_start = trial_ints.get('trial_start', np.array([[np.nan, np.nan]]))[:, 0] 1fjaegb
803 bpod_start = self.bpod_trials['intervals'][:, 0] 1fjaegb
804 if len(t_trial_start) > len(bpod_start) / 2: # if least half the trial start TTLs detected 1fjaegb
805 _logger.warning('Attempting to get protocol period from aligning trial start TTLs') 1faegb
806 fcn, *_ = ibldsp.utils.sync_timestamps(bpod_start, t_trial_start) 1faegb
807 buffer = 2.5 # the number of seconds to include before/after task 1faegb
808 start, end = fcn(self.bpod_trials['intervals'].flat[[0, -1]]) 1faegb
809 # NB: The following was added by k1o0 in commit b31d14e5113180b50621c985b2f230ba84da1dd3
810 # however it is not clear why this was necessary and it appears to defeat the purpose of
811 # removing the passive protocol part from the final trial extraction in ephysChoiceWorld.
812 # tmin = min(sync['times'][0], start - buffer)
813 # tmax = max(sync['times'][-1], end + buffer)
814 tmin = start - buffer 1faegb
815 tmax = end + buffer 1faegb
816 else: # This type of alignment fails for some sessions, e.g. mesoscope
817 tmin = tmax = None 1ja
819 # Remove unnecessary data from sync
820 selection = np.logical_and( 1fjiaecgbdh
821 sync['times'] <= (tmax if tmax is not None else sync['times'][-1]),
822 sync['times'] >= (tmin if tmin is not None else sync['times'][0]),
823 )
824 sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) 1fjiaecgbdh
825 _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', 1fjiaecgbdh
826 *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60)
828 # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets
829 out = self.build_trials(sync=sync, chmap=chmap, **kwargs) 1fjiaecgbdh
831 # extract the wheel data
832 if any(x.startswith('wheel') for x in self.var_names): 1fjiaecgbdh
833 wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) 1fjaecgbdh
834 from ibllib.io.extractors.training_wheel import extract_first_movement_times 1fjaecgbdh
835 if not self.settings: 1fjaecgbdh
836 self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
837 min_qt = self.settings.get('QUIESCENT_PERIOD', None) 1fjaecgbdh
838 first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) 1fjaecgbdh
839 out.update({'firstMovement_times': first_move_onsets}) 1fjaecgbdh
840 out.update({f'wheel_{k}': v for k, v in wheel.items()}) 1fjaecgbdh
841 out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) 1fjaecgbdh
843 # Re-create trials table
844 if table_columns: 1fjiaecgbdh
845 trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) 1fjaecgbdh
846 out['table'] = trials_table.to_df() 1fjaecgbdh
848 out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output 1fjiaecgbdh
849 assert self.var_names == tuple(out.keys()) 1fjiaecgbdh
850 return out 1fjiaecgbdh
852 def _is_trials_object_attribute(self, var_name, variable_length_vars=None):
853 """
854 Check if variable name is expected to have the same length as trials.intervals.
856 Parameters
857 ----------
858 var_name : str
859 The variable name to check.
860 variable_length_vars : list
861 Set of variable names that are not expected to have the same length as trials.intervals.
862 This list may be passed by superclasses.
864 Returns
865 -------
866 bool
867 True if variable is a trials dataset.
869 Examples
870 --------
871 >>> assert self._is_trials_object_attribute('stimOnTrigger_times') is True
872 >>> assert self._is_trials_object_attribute('wheel_position') is False
873 """
874 save_name = self.save_names[self.var_names.index(var_name)] if var_name in self.var_names else None 12faecgbdh
875 if save_name: 12faecgbdh
876 return filename_parts(save_name)[1] == 'trials' 12faecgbdh
877 else:
878 return var_name not in (variable_length_vars or []) 12faecgbdh
880 def build_trials(self, sync, chmap, display=False, **kwargs):
881 """
882 Extract task related event times from the sync.
884 The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The
885 first trial start TTL of the session is longer and must be handled differently. The trial
886 start TTL is used to assign the other trial events to each trial.
888 The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest
889 of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio
890 tones. The first of these after each trial start is taken to be the go cue time. Error
891 tones are longer audio TTLs and assigned as the last of such occurrence after each trial
892 start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial.
893 The feedback times are times of either valve open or error tone as there should be only one
894 such event per trial.
896 The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs
897 removed): the first TTL after each trial start is assumed to be the stim onset time; the
898 second to last and last are taken as the stimulus freeze and offset times, respectively.
900 Parameters
901 ----------
902 sync : dict
903 'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
904 chmap : dict
905 Map of channel names and their corresponding index. Default to constant.
906 display : bool, matplotlib.pyplot.Axes
907 Show the full session sync pulses display.
909 Returns
910 -------
911 dict
912 A map of trial event timestamps.
913 """
914 # Get the events from the sync.
915 # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
916 self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) 1faecgbdh
917 self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) 1faecgbdh
918 if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: 1faecgbdh
919 raise ValueError(
920 'Expected at least "ready_tone" and "error_tone" audio events.'
921 '`audio_event_ttls` kwarg may be incorrect.')
922 self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) 1faecgbdh
923 if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}: 1faecgbdh
924 raise ValueError(
925 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. '
926 '`bpod_event_ttls` kwarg may be incorrect.')
928 t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T 1faecgbdh
929 fpga_events = alfio.AlfBunch({ 1faecgbdh
930 'goCue_times': audio_event_intervals['ready_tone'][:, 0],
931 'errorCue_times': audio_event_intervals['error_tone'][:, 0],
932 'valveOpen_times': bpod_event_intervals['valve_open'][:, 0],
933 'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
934 'itiIn_times': t_iti_in,
935 'intervals_0': bpod_event_intervals['trial_start'][:, 0],
936 'intervals_1': t_trial_end
937 })
939 # Sync the Bpod clock to the DAQ.
940 # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is
941 # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is
942 # dropped after assigning the FPGA events, using the `ibpod` index. Doing this after
943 # assigning the FPGA trial events ensures the last trial has the correct timestamps.
944 self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) 1faecgbdh
946 bpod_start = self.bpod2fpga(self.bpod_trials['intervals'][:, 0]) 1faecgbdh
947 missing_bpod_idx = np.setxor1d(ibpod, np.arange(len(bpod_start))) 1faecgbdh
948 if missing_bpod_idx.size > 0 and self.sync_field == 'intervals_0': 1faecgbdh
949 # One issue is that sometimes pulses may not have been detected, in this case
950 # add the events that have not been detected and re-extract the behaviour sync.
951 # This is only really relevant for the Bpod interval events as the other TTLs are
952 # from devices where a missing TTL likely means the Bpod event was truly absent.
953 _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') 1cb
954 missing_bpod = bpod_start[missing_bpod_idx] 1cb
955 # Another complication: if the first trial start is missing on the FPGA, the second
956 # trial start is assumed to be the first and is mis-assigned to another trial event
957 # (i.e. valve open). This is done because the first Bpod pulse is irregularly long.
958 # See `FpgaTrials.get_bpod_event_times` for details.
960 # If first trial start is missing first detected FPGA event doesn't match any Bpod
961 # starts then it's probably a mis-assigned valve or trial end event.
962 i1 = np.any(missing_bpod_idx == 0) and not np.any(np.isclose(fpga_events['intervals_0'][0], bpod_start)) 1cb
963 # skip mis-assigned first FPGA trial start
964 t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][int(i1):], missing_bpod]) 1cb
965 ibpod = np.sort(np.r_[ibpod, missing_bpod_idx]) 1cb
966 if i1: 1cb
967 # The first trial start is actually the first valve open here
968 first_on, first_off = bpod_event_intervals['trial_start'][0, :] 1cb
969 bpod_valve_open = self.bpod2fpga(self.bpod_trials['feedback_times'][self.bpod_trials['feedbackType'] == 1]) 1cb
970 if np.any(np.isclose(first_on, bpod_valve_open)): 1cb
971 # Probably assigned to the valve open
972 _logger.debug('Re-reassigning first valve open event. TTL length = %.3g ms', first_off - first_on) 1c
973 fpga_events['valveOpen_times'] = np.sort(np.r_[first_on, fpga_events['valveOpen_times']]) 1c
974 fpga_events['valveClose_times'] = np.sort(np.r_[first_off, fpga_events['valveClose_times']]) 1c
975 elif np.any(np.isclose(first_on, self.bpod2fpga(self.bpod_trials['itiIn_times']))): 1cb
976 # Probably assigned to the trial end
977 _logger.debug('Re-reassigning first trial end event. TTL length = %.3g ms', first_off - first_on) 1c
978 fpga_events['itiIn_times'] = np.sort(np.r_[first_on, fpga_events['itiIn_times']]) 1c
979 fpga_events['intervals_1'] = np.sort(np.r_[first_off, fpga_events['intervals_1']]) 1c
980 else:
981 _logger.warning('Unable to reassign first trial start event. TTL length = %.3g ms', first_off - first_on) 1b
982 # Bpod trial_start event intervals are not used but for consistency we'll update them here anyway
983 bpod_event_intervals['trial_start'] = bpod_event_intervals['trial_start'][1:, :] 1cb
984 else:
985 t_trial_start = fpga_events['intervals_0'] 1faegdh
987 out = alfio.AlfBunch() 1faecgbdh
988 # Add the Bpod trial events, converting the timestamp fields to FPGA time.
989 # NB: The trial intervals are by default a Bpod rsync field.
990 out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) 1faecgbdh
991 for k in self.bpod_rsync_fields: 1faecgbdh
992 # Some personal projects may extract non-trials object datasets that may not have 1 event per trial
993 idx = ibpod if self._is_trials_object_attribute(k) else np.arange(len(self.bpod_trials[k]), dtype=int) 1faecgbdh
994 out[k] = self.bpod2fpga(self.bpod_trials[k][idx]) 1faecgbdh
996 f2ttl_t = self.frame2ttl['times'] 1faecgbdh
997 # Assign the FPGA events to individual trials
998 fpga_trials = { 1faecgbdh
999 'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'),
1000 'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']),
1001 'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']),
1002 'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']),
1003 'stimOn_times': np.full_like(t_trial_start, np.nan),
1004 'stimOff_times': np.full_like(t_trial_start, np.nan),
1005 'stimFreeze_times': np.full_like(t_trial_start, np.nan)
1006 }
1008 # f2ttl times are unreliable owing to calibration and Bonsai sync square update issues.
1009 # Take the first event after the FPGA aligned stimulus trigger time.
1010 fpga_trials['stimOn_times'] = _assign_events_to_trial( 1faecgbdh
1011 out['stimOnTrigger_times'], f2ttl_t, take='first', t_trial_end=out['stimOffTrigger_times'])
1012 fpga_trials['stimOff_times'] = _assign_events_to_trial( 1faecgbdh
1013 out['stimOffTrigger_times'], f2ttl_t, take='first', t_trial_end=out['intervals'][:, 1])
1014 # For stim freeze we take the last event before the stim off trigger time.
1015 # To avoid assigning early events (e.g. for sessions where there are few flips due to
1016 # mis-calibration), we discount events before stim freeze trigger times (or stim on trigger
1017 # times for versions below 6.2.5). We take the last event rather than the first after stim
1018 # freeze trigger because often there are multiple flips after the trigger, presumably
1019 # before the stim actually stops.
1020 stim_freeze = np.copy(out['stimFreezeTrigger_times']) 1faecgbdh
1021 go_trials = np.where(out['choice'] != 0)[0] 1faecgbdh
1022 # NB: versions below 6.2.5 have no trigger times so use stim on trigger times
1023 lims = np.copy(out['stimOnTrigger_times']) 1faecgbdh
1024 if not np.isnan(stim_freeze).all(): 1faecgbdh
1025 # Stim freeze times are NaN for nogo trials, but for all others use stim freeze trigger
1026 # times. _assign_events_to_trial requires ascending timestamps so no NaNs allowed.
1027 lims[go_trials] = stim_freeze[go_trials] 1febd
1028 # take last event after freeze/stim on trigger, before stim off trigger
1029 stim_freeze = _assign_events_to_trial(lims, f2ttl_t, take='last', t_trial_end=out['stimOffTrigger_times']) 1faecgbdh
1030 fpga_trials['stimFreeze_times'][go_trials] = stim_freeze[go_trials] 1faecgbdh
1031 # Feedback times are valve open on correct trials and error tone in on incorrect trials
1032 fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) 1faecgbdh
1033 ind_err = np.isnan(fpga_trials['valveOpen_times']) 1faecgbdh
1034 fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] 1faecgbdh
1036 # Use ibpod to discard the final trial if it is incomplete
1037 # ibpod should be indices of all Bpod trials, even those that were not detected on the FPGA
1038 out.update({k: fpga_trials[k][ibpod] for k in fpga_trials.keys()}) 1faecgbdh
1040 if display: # pragma: no cover 1faecgbdh
1041 width = 2
1042 ymax = 5
1043 if isinstance(display, bool):
1044 plt.figure('Bpod FPGA Sync')
1045 ax = plt.gca()
1046 else:
1047 ax = display
1048 plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
1049 plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
1050 plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
1051 color_map = TABLEAU_COLORS.keys()
1052 for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)):
1053 plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
1054 # Plot the stimulus events along with the trigger times
1055 stim_events = filter(lambda t: 'stim' in t[0], fpga_trials.items())
1056 for (event_name, event_times), c in zip(stim_events, cycle(color_map)):
1057 plots.vertical_lines(
1058 event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width, linestyle='--')
1059 nm = event_name.replace('_times', 'Trigger_times')
1060 plots.vertical_lines(
1061 out[nm], ymin=0, ymax=ymax, ax=ax, color=c, label=nm, linewidth=width, linestyle=':')
1062 ax.legend()
1063 ax.set_yticks([0, 1, 2, 3])
1064 ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
1065 ax.set_ylim([0, 5])
1066 return out 1faecgbdh
1068 def get_wheel_positions(self, *args, **kwargs):
1069 """Extract wheel and wheelMoves objects.
1071 This method is called by the main extract method and may be overloaded by subclasses.
1072 """
1073 return get_wheel_positions(*args, **kwargs) 1faecgbdh
1075 def get_stimulus_update_times(self, sync, chmap, display=False, **_):
1076 """
1077 Extract stimulus update times from sync.
1079 Gets the stimulus times from the frame2ttl channel and cleans the signal.
1081 Parameters
1082 ----------
1083 sync : dict
1084 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1085 and the corresponding channel numbers.
1086 chmap : dict
1087 A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key.
1088 display : bool
1089 If true, plots the input TTLs and the cleaned output.
1091 Returns
1092 -------
1093 dict
1094 A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts.
1095 """
1096 frame2ttl = get_sync_fronts(sync, chmap['frame2ttl']) 1fjiaecgbdh
1097 frame2ttl = _clean_frame2ttl(frame2ttl, display=display) 1fjiaecgbdh
1098 return frame2ttl 1fjiaecgbdh
1100 def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_):
1101 """
1102 Extract audio times from sync.
1104 Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL
1105 event by length.
1107 Parameters
1108 ----------
1109 sync : dict
1110 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1111 and the corresponding channel numbers.
1112 chmap : dict
1113 A map of channel names and their corresponding indices. Must contain an 'audio' key.
1114 audio_event_ttls : dict
1115 A map of event names to (min, max) TTL length.
1116 display : bool
1117 If true, plots the input TTLs and the cleaned output.
1119 Returns
1120 -------
1121 dict
1122 A dictionary with keys {'times', 'polarities'} containing audio TTL fronts.
1123 dict
1124 A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array.
1125 """
1126 audio = get_sync_fronts(sync, chmap['audio']) 1wfjiaecgbdh
1127 audio = _clean_audio(audio) 1wfjiaecgbdh
1129 if audio['times'].size == 0: 1wfjiaecgbdh
1130 _logger.error('No audio sync fronts found.')
1132 if audio_event_ttls is None: 1wfjiaecgbdh
1133 # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error
1134 # tone should be between 400ms and 1200ms
1135 audio_event_ttls = {'ready_tone': (0, 0.1101), 'error_tone': (0.4, 1.2)} 1wfjiaecgbdh
1136 audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display) 1wfjiaecgbdh
1138 return audio, audio_event_intervals 1wfjiaecgbdh
1140 def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
1141 """
1142 Extract Bpod times from sync.
1144 Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by
1145 length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned.
1146 This method accounts for this.
1148 Parameters
1149 ----------
1150 sync : dict
1151 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1152 and the corresponding channel numbers. Must contain a 'bpod' key.
1153 chmap : dict
1154 A map of channel names and their corresponding indices.
1155 bpod_event_ttls : dict of tuple
1156 A map of event names to (min, max) TTL length.
1158 Returns
1159 -------
1160 dict
1161 A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
1162 dict
1163 A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
1164 """
1165 bpod = get_sync_fronts(sync, chmap['bpod']) 1vtfjaecgbmnolkdh
1166 if bpod.times.size == 0: 1vtfjaecgbmnolkdh
1167 raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
1168 'Check channel maps.')
1169 # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
1170 # lengths are defined by the state machine of the task protocol and therefore vary.
1171 if bpod_event_ttls is None: 1vtfjaecgbmnolkdh
1172 # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has
1173 # proven to drift on some Bpods and this is the highest possible value that
1174 # discriminates trial start from valve. Valve open events are between 50ms to 300 ms.
1175 # ITI events are above 400 ms.
1176 bpod_event_ttls = { 1vtfaecgbmnolkdh
1177 'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
1178 bpod_event_intervals = self._assign_events( 1vtfjaecgbmnolkdh
1179 bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
1181 if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0: 1vtfjaecgbmnolkdh
1182 return bpod, bpod_event_intervals 1j
1184 # The first trial pulse is longer and often assigned to another event.
1185 # Here we move the earliest non-trial_start event to the trial_start array.
1186 t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start 1vtfaecgbmnolkdh
1187 pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] 1vtfaecgbmnolkdh
1188 if pretrial: 1vtfaecgbmnolkdh
1189 (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event 1vtfaecgbmnolkdh
1190 dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log 1vtfaecgbmnolkdh
1191 _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) 1vtfaecgbmnolkdh
1192 bpod_event_intervals['trial_start'] = np.r_[ 1vtfaecgbmnolkdh
1193 bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start']
1194 ]
1195 bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] 1vtfaecgbmnolkdh
1197 return bpod, bpod_event_intervals 1vtfaecgbmnolkdh
1199 @staticmethod
1200 def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False):
1201 """
1202 Classify TTL events by length.
1204 Outputs the synchronisation events such as trial intervals, valve opening, and audio.
1206 Parameters
1207 ----------
1208 ts : numpy.array
1209 Numpy vector containing times of TTL fronts.
1210 polarities : numpy.array
1211 Numpy vector containing polarity of TTL fronts (1 rise, -1 fall).
1212 event_lengths : dict of tuple
1213 A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1.
1214 precedence : str {'shortest', 'longest', 'dict order'}
1215 In the case of overlapping event TTL lengths, assign shortest/longest first or go by
1216 the `event_lengths` dict order.
1217 display : bool
1218 If true, plots the TTLs with coloured lines delineating the assigned events.
1220 Returns
1221 -------
1222 Dict[str, numpy.array]
1223 A dictionary of events and their intervals as an Nx2 array.
1225 See Also
1226 --------
1227 _assign_events_to_trial - classify TTLs by event order within a given trial period.
1228 """
1229 event_intervals = dict.fromkeys(event_lengths) 1wvtfjiaecgbmnolkdh
1230 assert 'unassigned' not in event_lengths.keys() 1wvtfjiaecgbmnolkdh
1232 if len(ts) == 0: 1wvtfjiaecgbmnolkdh
1233 return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')}
1235 # make sure that there are no 2 consecutive fall or consecutive rise events
1236 assert np.all(np.abs(np.diff(polarities)) == 2) 1wvtfjiaecgbmnolkdh
1237 if polarities[0] == -1: 1wvtfjiaecgbmnolkdh
1238 ts = np.delete(ts, 0) 1wt
1239 if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN 1wvtfjiaecgbmnolkdh
1240 ts = np.r_[ts, np.nan] 1t
1241 # take only even time differences: i.e. from rising to falling fronts
1242 dt = np.diff(ts)[::2] 1wvtfjiaecgbmnolkdh
1244 # Assign events from shortest TTL to largest
1245 assigned = np.zeros(ts.shape, dtype=bool) 1wvtfjiaecgbmnolkdh
1246 if precedence.lower() == 'shortest': 1wvtfjiaecgbmnolkdh
1247 event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1])) 1wvtfjiaecgbmnolkdh
1248 elif precedence.lower() == 'longest':
1249 event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True)
1250 elif precedence.lower() == 'dict order':
1251 event_items = event_lengths.items()
1252 else:
1253 raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".')
1254 for event, (min_len, max_len) in event_items: 1wvtfjiaecgbmnolkdh
1255 _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len) 1wvtfjiaecgbmnolkdh
1256 i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2 1wvtfjiaecgbmnolkdh
1257 i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned 1wvtfjiaecgbmnolkdh
1258 event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]] 1wvtfjiaecgbmnolkdh
1259 assigned[np.r_[i_event, i_event + 1]] = True 1wvtfjiaecgbmnolkdh
1261 # Include the unassigned events for convenience and debugging
1262 event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2) 1wvtfjiaecgbmnolkdh
1264 # Assert that event TTLs mutually exclusive
1265 all_assigned = np.concatenate(list(event_intervals.values())).flatten() 1wvtfjiaecgbmnolkdh
1266 assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events' 1wvtfjiaecgbmnolkdh
1268 # some debug plots when needed
1269 if display: # pragma: no cover 1wvtfjiaecgbmnolkdh
1270 plt.figure()
1271 plots.squares(ts, polarities, label='raw fronts')
1272 for event, intervals in event_intervals.items():
1273 plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event)
1274 plt.legend()
1276 # Return map of event intervals in the same order as `event_lengths` dict
1277 return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')} 1wvtfjiaecgbmnolkdh
1279 @staticmethod
1280 def sync_bpod_clock(bpod_trials, fpga_trials, sync_field):
1281 """
1282 Sync the Bpod clock to FPGA one using the provided trial event.
1284 It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both
1285 intervals is not supported so to sync on trial start times, `sync_field` should be
1286 'intervals_0'.
1288 Parameters
1289 ----------
1290 bpod_trials : dict
1291 A dictionary of extracted Bpod trial events.
1292 fpga_trials : dict
1293 A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync`
1294 method).
1295 sync_field : str
1296 The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the
1297 column index, e.g. 'intervals_0'.
1299 Returns
1300 -------
1301 function
1302 Interpolation function such that f(timestamps_bpod) = timestamps_fpga.
1303 float
1304 The clock drift in parts per million.
1305 numpy.array of int
1306 The indices of the Bpod trial events in the FPGA trial events array.
1307 numpy.array of int
1308 The indices of the FPGA trial events in the Bpod trial events array.
1310 Raises
1311 ------
1312 ValueError
1313 The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts.
1314 """
1315 _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') 1fjiaecgbdh
1316 bpod_fpga_timestamps = [None, None] 1fjiaecgbdh
1317 for i, trials in enumerate((bpod_trials, fpga_trials)): 1fjiaecgbdh
1318 if sync_field not in trials: 1fjiaecgbdh
1319 # handle syncing on intervals
1320 if not (m := re.match(r'(.*)_(\d)', sync_field)): 1fiaecgbdh
1321 # If missing from bpod trials, either the sync field is incorrect,
1322 # or the Bpod extractor is incorrect. If missing from the fpga events, check
1323 # the sync field and the `extract_behaviour_sync` method.
1324 raise ValueError(
1325 f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events')
1326 _sync_field, n = m.groups() 1fiaecgbdh
1327 bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)] 1fiaecgbdh
1328 else:
1329 bpod_fpga_timestamps[i] = trials[sync_field] 1fjiaecgbdh
1331 # Sync the two timestamps
1332 fcn, drift, ibpod, ifpga = ibldsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True) 1fjiaecgbdh
1334 # If it's drifting too much throw warning or error
1335 _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', 1fjiaecgbdh
1336 *map(len, bpod_fpga_timestamps), len(ibpod), drift)
1337 if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size: 1fjiaecgbdh
1338 raise err.SyncBpodFpgaException('sync cluster f*ck')
1339 elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: 1fjiaecgbdh
1340 _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm',
1341 BPOD_FPGA_DRIFT_THRESHOLD_PPM)
1343 return fcn, drift, ibpod, ifpga 1fjiaecgbdh
1346class FpgaTrialsHabituation(FpgaTrials):
1347 """Extract habituationChoiceWorld trial events from an NI DAQ."""
1349 save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy',
1350 '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy',
1351 '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy',
1352 '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy',
1353 None, None, None, None, None)
1354 """tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""
1356 var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft',
1357 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
1358 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
1359 'stimCenterTrigger_times', 'position', 'phase')
1360 """tuple of str: A list of names for the extracted variables. These become the returned output keys."""
1362 bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times',
1363 'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times',
1364 'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times')
1365 """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
1367 bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase')
1368 """tuple of str: Fields from Bpod extractor that we want to save."""
1370 sync_field = 'feedback_times' # valve open events
1371 """str: The trial event to synchronize (must be present in extracted trials)."""
1373 def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
1374 task_collection='raw_behavior_data', **kwargs) -> dict:
1375 """
1376 Extract habituationChoiceWorld trial events from an NI DAQ.
1378 It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
1379 attributes are all correct for the bpod protocol used.
1381 Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock
1382 using the valve open times, instead of the trial start times.
1384 Parameters
1385 ----------
1386 sync : dict
1387 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1388 and the corresponding channel numbers. If None, the sync is loaded using the
1389 `load_sync` method.
1390 dict
1391 A map of channel names and their corresponding indices. If None, the channel map is
1392 loaded using the `load_sync` method.
1393 sync_collection : str
1394 The session subdirectory where the sync data are located. This is only used if the
1395 sync or channel maps are not provided.
1396 task_collection : str
1397 The session subdirectory where the raw Bpod data are located. This is used for loading
1398 the task settings and extracting the bpod trials, if not already done.
1399 protocol_number : int
1400 The protocol number if multiple protocols were run during the session. If provided, a
1401 spacer signal must be present in order to determine the correct period.
1402 kwargs
1403 Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'.
1405 Returns
1406 -------
1407 dict
1408 A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys.
1409 """
1410 # Version check: the ITI in TTL was added in a later version
1411 if not self.settings: 1ia
1412 self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
1413 iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) 1ia
1414 if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): 1ia
1415 """A second 1s TTL was added in this version during the 'iti' state, however this is
1416 unrelated to the trial ITI and is unfortunately the same length as the trial start TTL."""
1417 raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6')
1419 trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, 1ia
1420 task_collection=task_collection, **kwargs)
1422 return trials 1ia
1424 def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
1425 """
1426 Extract Bpod times from sync.
1428 Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse.
1429 Also the first trial pulse is incorrectly assigned due to its abnormal length.
1431 Parameters
1432 ----------
1433 sync : dict
1434 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1435 and the corresponding channel numbers. Must contain a 'bpod' key.
1436 chmap : dict
1437 A map of channel names and their corresponding indices.
1438 bpod_event_ttls : dict of tuple
1439 A map of event names to (min, max) TTL length.
1441 Returns
1442 -------
1443 dict
1444 A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
1445 dict
1446 A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
1447 """
1448 bpod = get_sync_fronts(sync, chmap['bpod']) 1ia
1449 if bpod.times.size == 0: 1ia
1450 raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
1451 'Check channel maps.')
1452 # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
1453 # lengths are defined by the state machine of the task protocol and therefore vary.
1454 if bpod_event_ttls is None: 1ia
1455 # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse
1456 bpod_event_ttls = {'trial_iti': (.999, 1.1), 'valve_open': (0, 0.4)} 1ia
1457 bpod_event_intervals = self._assign_events( 1ia
1458 bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
1460 # The first trial pulse is shorter and assigned to valve_open. Here we remove the first
1461 # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was
1462 # incomplete in Bpod.
1463 t0 = bpod_event_intervals['trial_iti'][0, 0] # expect 1st event to be trial_start 1ia
1464 pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] 1ia
1465 if pretrial: 1ia
1466 (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event 1a
1467 dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log 1a
1468 _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) 1a
1469 bpod_event_intervals['trial_iti'] = np.r_[ 1a
1470 bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_iti']
1471 ]
1472 bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] 1a
1474 return bpod, bpod_event_intervals 1ia
1476 def build_trials(self, sync, chmap, display=False, **kwargs):
1477 """
1478 Extract task related event times from the sync.
1480 This is called by the superclass `_extract` method. The key difference here is that the
1481 `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start.
1483 Parameters
1484 ----------
1485 sync : dict
1486 'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
1487 chmap : dict
1488 Map of channel names and their corresponding index. Default to constant.
1489 display : bool, matplotlib.pyplot.Axes
1490 Show the full session sync pulses display.
1492 Returns
1493 -------
1494 dict
1495 A map of trial event timestamps.
1496 """
1497 # Get the events from the sync.
1498 # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
1499 self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) 1ia
1500 self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) 1ia
1501 self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) 1ia
1502 if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}: 1ia
1503 raise ValueError(
1504 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.')
1506 fpga_events = alfio.AlfBunch({ 1ia
1507 'feedback_times': bpod_event_intervals['valve_open'][:, 0],
1508 'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
1509 'intervals_0': bpod_event_intervals['trial_iti'][:, 1],
1510 'intervals_1': bpod_event_intervals['trial_iti'][:, 0],
1511 'goCue_times': audio_event_intervals['ready_tone'][:, 0]
1512 })
1514 # Sync the Bpod clock to the DAQ.
1515 self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) 1ia
1517 out = alfio.AlfBunch() 1ia
1518 # Add the Bpod trial events, converting the timestamp fields to FPGA time.
1519 # NB: The trial intervals are by default a Bpod rsync field.
1520 out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) 1ia
1521 out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) 1ia
1523 # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN)
1524 trials = alfio.AlfBunch({ 1ia
1525 'goCue_times': _assign_events_to_trial(out['goCueTrigger_times'], fpga_events['goCue_times'], take='first'),
1526 'feedback_times': _assign_events_to_trial(fpga_events['intervals_0'], fpga_events['feedback_times']),
1527 'stimCenter_times': _assign_events_to_trial(
1528 out['stimCenterTrigger_times'], self.frame2ttl['times'], take='first', t_trial_end=out['stimOffTrigger_times']),
1529 'stimOn_times': _assign_events_to_trial(
1530 out['stimOnTrigger_times'], self.frame2ttl['times'], take='first', t_trial_end=out['stimCenterTrigger_times']),
1531 'stimOff_times': _assign_events_to_trial(
1532 out['stimOffTrigger_times'], self.frame2ttl['times'],
1533 take='first', t_trial_end=np.r_[out['intervals'][1:, 0], np.inf])
1534 })
1535 out.update({k: trials[k][ifpga] for k in trials.keys()}) 1ia
1537 # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off
1538 to_correct = ~np.isnan(out['stimOn_times']) & (out['stimOn_times'] < out['intervals'][:, 0]) 1ia
1539 if np.any(to_correct): 1ia
1540 _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct))
1541 out['intervals'][to_correct, 0] = out['stimOn_times'][to_correct]
1542 to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1]) 1ia
1543 if np.any(to_correct): 1ia
1544 _logger.debug( 1ia
1545 '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
1546 sum(to_correct), len(to_correct))
1547 out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct] 1ia
1549 if display: # pragma: no cover 1ia
1550 width = 0.5
1551 ymax = 5
1552 if isinstance(display, bool):
1553 plt.figure('Bpod FPGA Sync')
1554 ax = plt.gca()
1555 else:
1556 ax = display
1557 plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
1558 plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
1559 plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
1560 color_map = TABLEAU_COLORS.keys()
1561 for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)):
1562 plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
1563 ax.legend()
1564 ax.set_yticks([0, 1, 2, 3])
1565 ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
1566 ax.set_ylim([0, 4])
1568 return out 1ia
1571def get_sync_and_chn_map(session_path, sync_collection):
1572 """
1573 Return sync and channel map for session based on collection where main sync is stored.
1575 Parameters
1576 ----------
1577 session_path : str, pathlib.Path
1578 The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1579 sync_collection : str
1580 The session subdirectory where the sync data are located.
1582 Returns
1583 -------
1584 one.alf.io.AlfBunch
1585 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
1586 the corresponding channel numbers.
1587 dict
1588 A map of channel names and their corresponding indices.
1589 """
1590 if sync_collection == 'raw_ephys_data': 1FGfNRTUVOPQuHIaecgbdh
1591 # Check to see if we have nidq files, if we do just go with this otherwise go into other function that deals with
1592 # 3A probes
1593 nidq_meta = next(session_path.joinpath(sync_collection).glob('*nidq.meta'), None) 1FGfNOPQuHIaecgbdh
1594 if not nidq_meta: 1FGfNOPQuHIaecgbdh
1595 sync, chmap = get_main_probe_sync(session_path) 1aecd
1596 else:
1597 sync = load_sync(session_path, sync_collection) 1FGfNOPQuHIagbdh
1598 ef = Bunch() 1FGfNOPQuHIagbdh
1599 ef['path'] = session_path.joinpath(sync_collection) 1FGfNOPQuHIagbdh
1600 ef['nidq'] = nidq_meta 1FGfNOPQuHIagbdh
1601 chmap = get_ibl_sync_map(ef, '3B') 1FGfNOPQuHIagbdh
1603 else:
1604 sync = load_sync(session_path, sync_collection) 1RTUV
1605 chmap = load_channel_map(session_path, sync_collection) 1RTUV
1607 return sync, chmap 1FGfNRTUVOPQuHIaecgbdh
1610def load_channel_map(session_path, sync_collection):
1611 """
1612 Load syncing channel map for session path and collection
1614 Parameters
1615 ----------
1616 session_path : str, pathlib.Path
1617 The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1618 sync_collection : str
1619 The session subdirectory where the sync data are located.
1621 Returns
1622 -------
1623 dict
1624 A map of channel names and their corresponding indices.
1625 """
1627 device = sync_collection.split('_')[1] 1RTUV
1628 default_chmap = DEFAULT_MAPS[device]['nidq'] 1RTUV
1630 # Try to load channel map from file
1631 chmap = spikeglx.get_sync_map(session_path.joinpath(sync_collection)) 1RTUV
1632 # If chmap provided but not with all keys, fill up with default values
1633 if not chmap: 1RTUV
1634 return default_chmap 1TUV
1635 else:
1636 if data_for_keys(default_chmap.keys(), chmap): 1R
1637 return chmap 1R
1638 else:
1639 _logger.warning('Keys missing from provided channel map, '
1640 'setting missing keys from default channel map')
1641 return {**default_chmap, **chmap}
1644def load_sync(session_path, sync_collection):
1645 """
1646 Load sync files from session path and collection.
1648 Parameters
1649 ----------
1650 session_path : str, pathlib.Path
1651 The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1652 sync_collection : str
1653 The session subdirectory where the sync data are located.
1655 Returns
1656 -------
1657 one.alf.io.AlfBunch
1658 A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
1659 the corresponding channel numbers.
1660 """
1661 sync = alfio.load_object(session_path.joinpath(sync_collection), 'sync', namespace='spikeglx', short_keys=True) 1FGfNRTUVOPQuHIagbdh
1663 return sync 1FGfNRTUVOPQuHIagbdh