Coverage for ibllib/io/extractors/fibrephotometry.py: 94%
109 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""Data extraction from fibrephotometry DAQ files.
3Below is the expected folder structure for a fibrephotometry session:
5 subject/
6 ├─ 2021-06-30/
7 │ ├─ 001/
8 │ │ ├─ raw_photometry_data/
9 │ │ │ │ ├─ _neurophotometrics_fpData.raw.pqt
10 │ │ │ │ ├─ _neurophotometrics_fpData.channels.csv
11 │ │ │ │ ├─ _mcc_DAQdata.raw.tdms
13fpData.raw.pqt is a copy of the 'FPdata' file, the output of the Neurophotometrics Bonsai workflow.
14fpData.channels.csv is a table of frame flags for deciphering LED and GPIO states. The default table,
15copied from the Neurophotometrics manual, can be found in iblscripts/deploy/fppc/
16_mcc_DAQdata.raw.tdms is the DAQ tdms file, containing the pulses from bpod and from the neurophotometrics system
17"""
18import logging
20import pandas as pd
21import numpy as np
22import scipy.interpolate
24import one.alf.io as alfio
25from ibllib.io.extractors.base import BaseExtractor
26from ibllib.io.raw_daq_loaders import load_channels_tdms, load_raw_daq_tdms
27from ibllib.io.extractors.training_trials import GoCueTriggerTimes
28from ibldsp.utils import rises, sync_timestamps
_logger = logging.getLogger(__name__)

# Maps the trace name used in this module to the DAQ analog input channel name
DAQ_CHMAP = dict(photometry='AI0', bpod='AI1')
# Default threshold passed to the pulse detection and bpod tagging functions of this module
V_THRESHOLD = 3

# Neurophotometrics FP3002 specific information.
# The light source map refers to the available LEDs on the system.
# The flags refer to the byte encoding of led states in the system.
LIGHT_SOURCE_MAP = dict(
    color=['None', 'Violet', 'Blue', 'Green'],
    wavelength=[0, 415, 470, 560],
    name=['None', 'Isosbestic', 'GCaMP', 'RCaMP'],
)

# Each column is keyed by the condition number (0-11); the position in each list is that key.
NEUROPHOTOMETRICS_LED_STATES = {
    'Condition': dict(enumerate([
        'No additional signal',
        'Output 1 signal HIGH',
        'Output 0 signal HIGH',
        'Stimulation ON',
        'GPIO Line 2 HIGH',
        'GPIO Line 3 HIGH',
        'Input 1 HIGH',
        'Input 0 HIGH',
        'Output 0 signal HIGH + Stimulation',
        'Output 0 signal HIGH + Input 0 signal HIGH',
        'Input 0 signal HIGH + Stimulation',
        'Output 0 HIGH + Input 0 HIGH + Stimulation',
    ])),
    'No LED ON': dict(enumerate([0, 8, 16, 32, 64, 128, 256, 512, 48, 528, 544, 560])),
    'L415': dict(enumerate([1, 9, 17, 33, 65, 129, 257, 513, 49, 529, 545, 561])),
    'L470': dict(enumerate([2, 10, 18, 34, 66, 130, 258, 514, 50, 530, 546, 562])),
    'L560': dict(enumerate([4, 12, 20, 36, 68, 132, 260, 516, 52, 532, 548, 564])),
}
def sync_photometry_to_daq(vdaq, fs, df_photometry, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Express the photometry frame timestamps in DAQ time.

    The frame pulses detected on the DAQ are paired with the photometry frames by comparing, for a
    range of frame shifts, the bpod state sampled on the DAQ at each pulse with the 'Input0' column of
    the photometry table. The shift with the smallest mismatch is used to pair frames, and a linear
    interpolant from photometry time to DAQ time is built from those pairs.

    :param vdaq: dictionary of daq traces; the 'photometry' and 'bpod' traces are read
    :param fs: sampling frequency, used to convert DAQ sample indices to seconds
    :param df_photometry: photometry table; the columns 'Input0', 'Timestamp' and 'FrameCounter' are read
    :param chmap: channel map; accepted for backward compatibility but not used by this function
    :param v_threshold: threshold forwarded to read_daq_timestamps
    :return: tuple (ts_daq, fcn_fp2daq, drift_ppm): the photometry timestamps converted to daq time,
     the interpolation function that converts photometry time to daq time, and the clock drift between
     the two systems in parts per million
    :raises AssertionError: if the photometry frame counter increment is not constant, if the frame
     shift exceeds 5 frames, or if the absolute drift is 60 ppm or more
    """
    daq_frames, tag_daq_frames = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold)
    nf = np.minimum(tag_daq_frames.size, df_photometry['Input0'].size)

    # we compute the framecounter for the DAQ, and match the bpod up state frame by frame for different shifts
    # the shift that minimizes the mismatch is usually good
    df = np.median(np.diff(df_photometry['Timestamp']))
    fc = np.cumsum(np.round(np.diff(daq_frames) / fs / df).astype(np.int32)) - 1  # this is a daq frame counter
    fc = fc[fc < (nf - 1)]
    max_shift = 300
    shifts = np.arange(-max_shift, max_shift + 1)
    # both operands of the comparison are independent of the shift: index them once outside of the loop
    input0 = df_photometry['Input0'].values[fc]
    tags = tag_daq_frames[:fc.size]
    error = np.zeros(shifts.size)
    for i, shift in enumerate(shifts):
        error[i] = np.sum(np.abs(np.roll(input0, shift) - tags))
    # a negative shift means that the DAQ is ahead of the photometry and that the DAQ misses frame at the beginning
    frame_shift = shifts[np.argmin(error)]
    if frame_shift < 0:
        ifp = fc[-frame_shift:]
    elif frame_shift > 0:
        ifp = fc[:-frame_shift]
    else:
        ifp = fc
    t_photometry = df_photometry['Timestamp'].values[ifp]
    t_daq = daq_frames[:ifp.size] / fs
    fcn_fp2daq = scipy.interpolate.interp1d(t_photometry, t_daq, fill_value='extrapolate')
    drift_ppm = (np.polyfit(t_daq, t_photometry, 1)[0] - 1) * 1e6
    # the drift can have either sign: warn on its magnitude, consistently with the assertion below
    if np.abs(drift_ppm) > 120:
        _logger.warning(f"drift photometry to DAQ PPM: {drift_ppm}")
    else:
        _logger.info(f"drift photometry to DAQ PPM: {drift_ppm}")
    # here is a bunch of safeguards
    # checks that there are no missed frames on photo
    assert np.unique(np.diff(df_photometry['FrameCounter'])).size == 1, \
        "The photometry FrameCounter increment is not constant"
    # it's always the end frames that are missing
    assert np.abs(frame_shift) <= 5, f"Frame shift between DAQ and photometry is above 5 frames: {frame_shift}"
    assert np.abs(drift_ppm) < 60, f"Drift between photometry and DAQ is above 60 ppm: {drift_ppm}"
    ts_daq = fcn_fp2daq(df_photometry['Timestamp'].values)  # those are the timestamps in daq time
    return ts_daq, fcn_fp2daq, drift_ppm
def read_daq_voltage(daq_file, chmap=DAQ_CHMAP):
    """
    Load the DAQ traces listed in the channel map and subtract the median from each of them.

    :param daq_file: the DAQ tdms file to read
    :param chmap: dictionary whose values are the names of the 'Analog' channels to load
    :return: tuple (vdaq, fs): dictionary of median-subtracted traces, and fs as returned by load_channels_tdms
    :raises AssertionError: if one of the channels of the map is not found in the 'Analog' group of the file
    """
    available = {channel.name for channel in load_raw_daq_tdms(daq_file)['Analog'].channels()}
    assert all(name in available for name in chmap.values()), "Missing channel"
    traces, fs = load_channels_tdms(daq_file, chmap=chmap)
    centred = {}
    for key, trace in traces.items():
        centred[key] = trace - np.median(trace)
    return centred, fs
def read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD):
    """
    From the DAQ voltage traces, extracts the photometry frames and their tagging.

    :param vdaq: dictionary of the voltage traces from the DAQ. Each item has a key describing
    the channel as per the channel map, and contains a single voltage trace.
    The 'photometry' and 'bpod' traces are read.
    :param v_threshold: step passed to the rise detection on the photometry trace, and level above
    which the bpod trace is considered high at a given frame
    :return: tuple (daq_frames, tagged_frames): the indices of the detected photometry pulses, and a
    boolean array of the same size, True where the bpod trace is above the threshold at the pulse
    """
    daq_frames = rises(vdaq['photometry'], step=v_threshold, analog=True)
    if daq_frames.size == 0:
        # no pulse found: retry on the sign-flipped trace
        daq_frames = rises(-vdaq['photometry'], step=v_threshold, analog=True)
        _logger.warning(f'No photometry pulses detected, attempting to reverse voltage and detect again, '
                        f'found {daq_frames.size} in reverse voltage. CHECK YOUR FP WIRING TO THE DAQ !!')
    tagged_frames = vdaq['bpod'][daq_frames] > v_threshold
    return daq_frames, tagged_frames
def check_timestamps(daq_file, photometry_file, tolerance=20, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Reads data file and checks that the number of timestamps check out with a tolerance of n_frames

    The check is symmetric: it fails when either file has `tolerance` or more frames in excess of the other.

    :param daq_file: path (pathlib.Path) to the DAQ tdms file
    :param photometry_file: path (pathlib.Path) to the photometry file, read with pandas.read_csv
    :param tolerance: number of acceptable missing frames between the daq and the photometry file
    :param chmap: channel map forwarded to read_daq_voltage
    :param v_threshold: threshold forwarded to read_daq_timestamps
    :return: None
    :raises AssertionError: if the absolute difference in number of frames is not below the tolerance
    """
    df_photometry = pd.read_csv(photometry_file)
    v, _ = read_daq_voltage(daq_file=daq_file, chmap=chmap)
    daq_frames, _ = read_daq_timestamps(vdaq=v, v_threshold=v_threshold)
    n_daq, n_photometry = daq_frames.shape[0], df_photometry.shape[0]
    # compare the magnitude: a signed difference lets a large deficit of DAQ frames pass unnoticed
    assert abs(n_daq - n_photometry) < tolerance, (
        f"{n_daq - n_photometry} frames difference between the DAQ ({n_daq} frames) and the photometry "
        f"({n_photometry} frames), tolerance is {tolerance}")
    _logger.info(f"{n_daq - n_photometry} frames difference, "
                 f"{'/'.join(daq_file.parts[-2:])}: {n_daq} frames, "
                 f"{'/'.join(photometry_file.parts[-2:])}: {n_photometry}")
class FibrePhotometry(BaseExtractor):
    """
    Extractor for Neurophotometrics fibrephotometry data.

    Example: FibrePhotometry(self.session_path, collection=self.collection)
    """
    # Plain strings, not 1-tuples (the original parentheses had no trailing comma): _extract returns a single table
    save_names = 'photometry.signal.pqt'
    var_names = 'df_out'

    def __init__(self, *args, collection='raw_photometry_data', **kwargs):
        """An extractor for all Neurophotometrics fibrephotometry data"""
        self.collection = collection
        super().__init__(*args, **kwargs)

    @staticmethod
    def _channel_meta(light_source_map=None):
        """
        Return table of light source wavelengths and corresponding colour labels.

        Parameters
        ----------
        light_source_map : dict
            An optional map of light source wavelengths (nm) used and their corresponding colour name.
            Defaults to LIGHT_SOURCE_MAP.

        Returns
        -------
        pandas.DataFrame
            A table with one column per key of the map, and an index named 'channel_id'.
        """
        light_source_map = light_source_map or LIGHT_SOURCE_MAP
        meta = pd.DataFrame.from_dict(light_source_map)
        meta.index.rename('channel_id', inplace=True)
        return meta

    def _extract(self, light_source_map=None, collection=None, regions=None, **kwargs):
        """
        Extract the photometry signal of each region, with the light source active at each frame.

        Parameters
        ----------
        regions: list of str
            The list of regions to extract. If None extracts all columns containing "Region". Defaults to None.
        light_source_map : dict
            An optional map of light source wavelengths (nm) used and their corresponding colour name.
        collection: str / pathlib.Path
            An optional relative path from the session root folder to find the raw photometry data.
            Defaults to `raw_photometry_data`
        **kwargs
            Forwarded to `extract_timestamps`.

        Returns
        -------
        pandas.DataFrame
            A table of intensity for each region, with associated times, wavelengths, names and colors

        Raises
        ------
        ValueError
            If the raw data contains neither a 'LedState' nor a 'Flags' column.
        """
        collection = collection or self.collection
        fp_data = alfio.load_object(self.session_path / collection, 'fpData')
        # NOTE(review): extract_timestamps looks for the DAQ file in self.collection, not in `collection`
        ts = self.extract_timestamps(fp_data['raw'], **kwargs)

        # Load the channel metadata and the table of frame flags.
        # `light_source_map` is a named parameter and is therefore never found in kwargs: use it directly.
        channel_meta_map = self._channel_meta(light_source_map)
        led_states = fp_data.get('channels', pd.DataFrame(NEUROPHOTOMETRICS_LED_STATES))
        led_states = led_states.set_index('Condition')
        # Extract the signal columns, sorted by column name
        regions = regions or [k for k in fp_data['raw'].keys() if 'Region' in k]
        out_df = fp_data['raw'].filter(items=regions, axis=1).sort_index(axis=1)
        out_df['times'] = ts
        out_df['wavelength'] = np.nan  # the np.NaN alias was removed in NumPy 2.0
        out_df['name'] = ''
        out_df['color'] = ''
        # Extract channel index
        states = fp_data['raw'].get('LedState', fp_data['raw'].get('Flags', None))
        if states is None:
            raise ValueError("The raw photometry data has neither a 'LedState' nor a 'Flags' column")
        for state in states.unique():
            # the column position of the flag in the LED states table is used as the row of the channel table
            _, ic = np.where(led_states == state)
            if ic.size == 0:
                continue  # flag not found in the table: the frames keep their empty labels
            channel = channel_meta_map.iloc[ic[0]]
            for cn in ['name', 'color', 'wavelength']:
                out_df.loc[states == state, cn] = channel[cn]
        return out_df

    def extract_timestamps(self, fp_data, **kwargs):
        """Extract the photometry.timestamps array.

        This depends on the DAQ and task synchronization protocol.

        Parameters
        ----------
        fp_data : pandas.DataFrame
            The raw fibrephotometry table (`_extract` passes the 'raw' item of the fpData object); it is
            forwarded to `sync_photometry_to_daq` as `df_photometry`.

        Returns
        -------
        numpy.ndarray
            An array of timestamps, one per frame.

        Raises
        ------
        FileNotFoundError
            If no tdms file is found in the collection folder.
        AssertionError
            If the drift between bpod and daq is 100 ppm or more, or if 5 or more of the events detected
            on the daq could not be matched.
        """
        daq_folder = self.session_path.joinpath(self.collection)
        daq_file = next(daq_folder.glob('*.tdms'), None)
        if daq_file is None:
            raise FileNotFoundError(f"No tdms file found in {daq_folder}")
        vdaq, fs = read_daq_voltage(daq_file, chmap=DAQ_CHMAP)
        # frame timestamps expressed in DAQ time
        ts, _, _ = sync_photometry_to_daq(vdaq=vdaq, fs=fs, df_photometry=fp_data, v_threshold=V_THRESHOLD)
        gc_bpod, _ = GoCueTriggerTimes(session_path=self.session_path).extract(
            task_collection='raw_behavior_data', save=False)
        gc_daq = rises(vdaq['bpod'])

        fcn_daq2_bpod, drift_ppm, idaq, _ = sync_timestamps(gc_daq / fs, gc_bpod, return_indices=True)
        assert drift_ppm < 100, f"Drift between bpod and daq is above 100 ppm: {drift_ppm}"
        assert (gc_daq.size - idaq.size) < 5, "Bpod and daq synchronisation failed as too few " \
                                              "events could be matched"
        # convert the DAQ times with the function returned by the bpod / daq synchronisation
        ts = fcn_daq2_bpod(ts)
        return ts