Coverage for ibllib/io/extractors/fibrephotometry.py: 94%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Data extraction from fibrephotometry DAQ files. 

2 

3Below is the expected folder structure for a fibrephotometry session: 

4 

5 subject/ 

6 ├─ 2021-06-30/ 

7 │ ├─ 001/ 

8 │ │ ├─ raw_photometry_data/ 

9 │ │ │ │ ├─ _neurophotometrics_fpData.raw.pqt 

10 │ │ │ │ ├─ _neurophotometrics_fpData.channels.csv 

11 │ │ │ │ ├─ _mcc_DAQdata.raw.tdms 

12 

13fpData.raw.pqt is a copy of the 'FPdata' file, the output of the Neurophotometrics Bonsai workflow. 

14fpData.channels.csv is a table of frame flags for deciphering LED and GPIO states. The default table, 

15copied from the Neurophotometrics manual, can be found in iblscripts/deploy/fppc/ 

16_mcc_DAQdata.raw.tdms is the DAQ tdms file, containing the pulses from bpod and from the neurophotometrics system. 

17""" 

18import logging 

19 

20import pandas as pd 

21import numpy as np 

22import scipy.interpolate 

23 

24import one.alf.io as alfio 

25from ibllib.io.extractors.base import BaseExtractor 

26from ibllib.io.raw_daq_loaders import load_channels_tdms, load_raw_daq_tdms 

27from ibllib.io.extractors.training_trials import GoCueTriggerTimes 

28from ibldsp.utils import rises, sync_timestamps 

29 

_logger = logging.getLogger(__name__)

# Map of logical signal name -> analog input channel name on the DAQ.
DAQ_CHMAP = {"photometry": 'AI0', 'bpod': 'AI1'}
# Voltage threshold used to detect photometry frame pulses and bpod HIGH states on the DAQ traces.
V_THRESHOLD = 3

# Neurophotometrics FP3002 specific information.
# The light source map refers to the available LEDs on the system. Row i of this table is the
# light source associated with the i-th LED state column of NEUROPHOTOMETRICS_LED_STATES
# (i.e. 'No LED ON', 'L415', 'L470', 'L560', in that order).
LIGHT_SOURCE_MAP = {
    'color': ['None', 'Violet', 'Blue', 'Green'],
    'wavelength': [0, 415, 470, 560],
    'name': ['None', 'Isosbestic', 'GCaMP', 'RCaMP'],
}

# The flags refer to the byte encoding of LED states in the system: for each condition, the value
# recorded when no LED is on, and when the 415, 470 or 560 nm LED is on (which adds 1, 2 or 4
# respectively to the 'No LED ON' value).
NEUROPHOTOMETRICS_LED_STATES = {
    'Condition': {
        0: 'No additional signal',
        1: 'Output 1 signal HIGH',
        2: 'Output 0 signal HIGH',
        3: 'Stimulation ON',
        4: 'GPIO Line 2 HIGH',
        5: 'GPIO Line 3 HIGH',
        6: 'Input 1 HIGH',
        7: 'Input 0 HIGH',
        8: 'Output 0 signal HIGH + Stimulation',
        9: 'Output 0 signal HIGH + Input 0 signal HIGH',
        10: 'Input 0 signal HIGH + Stimulation',
        11: 'Output 0 HIGH + Input 0 HIGH + Stimulation',
    },
    'No LED ON': {0: 0, 1: 8, 2: 16, 3: 32, 4: 64, 5: 128, 6: 256, 7: 512, 8: 48, 9: 528, 10: 544, 11: 560},
    'L415': {0: 1, 1: 9, 2: 17, 3: 33, 4: 65, 5: 129, 6: 257, 7: 513, 8: 49, 9: 529, 10: 545, 11: 561},
    'L470': {0: 2, 1: 10, 2: 18, 3: 34, 4: 66, 5: 130, 6: 258, 7: 514, 8: 50, 9: 530, 10: 546, 11: 562},
    'L560': {0: 4, 1: 12, 2: 20, 3: 36, 4: 68, 5: 132, 6: 260, 7: 516, 8: 52, 9: 532, 10: 548, 11: 564},
}

66 

67 

def sync_photometry_to_daq(vdaq, fs, df_photometry, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Synchronise the photometry frame clock to the DAQ clock.

    The DAQ records one pulse per photometry frame, and the bpod state on a second channel.
    The bpod state seen by the photometry system ('Input0') is compared frame by frame with the
    bpod state sampled by the DAQ at each frame pulse, for a range of frame shifts. The shift with
    the smallest mismatch is used to pair photometry timestamps with DAQ frame times.

    :param vdaq: dictionary of daq traces, with at least the keys 'photometry' and 'bpod'.
    :param fs: sampling frequency of the DAQ.
    :param df_photometry: pandas.DataFrame of raw photometry data with the columns 'Timestamp',
        'Input0' and 'FrameCounter'.
    :param chmap: channel map; not used by this function, kept for API compatibility.
    :param v_threshold: voltage threshold used to detect the frame pulses and the bpod HIGH state.
    :return: tuple (ts_daq, fcn_fp2daq, drift_ppm): the photometry timestamps expressed in DAQ
        time, the interpolant mapping photometry time to DAQ time, and the drift between the
        photometry and DAQ clocks in parts per million.
    :raises AssertionError: if photometry frames were missed, if the best frame shift is larger
        than 5 frames, or if the absolute drift is 60 ppm or more.
    """
    max_shift = 300  # search range (in frames) of the photometry / DAQ alignment
    max_drift_ppm = 60  # maximum acceptable clock drift between photometry and DAQ

    daq_frames, tag_daq_frames = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold)
    input0 = df_photometry['Input0'].values
    nf = np.minimum(tag_daq_frames.size, input0.size)

    # Frame counter for the DAQ: each inter-pulse interval is converted into a number of
    # photometry frame periods, so that missed DAQ pulses advance the counter accordingly.
    dt = np.median(np.diff(df_photometry['Timestamp']))
    fc = np.cumsum(np.round(np.diff(daq_frames) / fs / dt).astype(np.int32)) - 1
    fc = fc[fc < (nf - 1)]

    # Match the bpod up state frame by frame for different shifts;
    # the shift that minimizes the mismatch is usually good.
    shifts = np.arange(-max_shift, max_shift + 1)
    fp_tags = input0[fc]
    daq_tags = tag_daq_frames[:fc.size]
    error = np.array([np.sum(np.abs(np.roll(fp_tags, shift) - daq_tags)) for shift in shifts])
    # A negative shift means that the DAQ is ahead of the photometry and that the DAQ misses
    # frames at the beginning.
    frame_shift = shifts[np.argmin(error)]
    if frame_shift < 0:
        ifp = fc[-frame_shift:]
    elif frame_shift > 0:
        ifp = fc[:-frame_shift]
    else:
        ifp = fc
    t_photometry = df_photometry['Timestamp'].values[ifp]
    t_daq = daq_frames[:ifp.size] / fs

    fcn_fp2daq = scipy.interpolate.interp1d(t_photometry, t_daq, fill_value='extrapolate')
    drift_ppm = (np.polyfit(t_daq, t_photometry, 1)[0] - 1) * 1e6
    # Warn with the same (absolute) limit as the safeguard below, so an out-of-range drift
    # is always reported at warning level.
    if np.abs(drift_ppm) >= max_drift_ppm:
        _logger.warning(f"drift photometry to DAQ PPM: {drift_ppm}")
    else:
        _logger.info(f"drift photometry to DAQ PPM: {drift_ppm}")

    # Safeguards
    # checks that there are no missed frames on photometry
    assert np.unique(np.diff(df_photometry['FrameCounter'])).size == 1, \
        "Photometry frame counter is irregular: frames were missed"
    # it's always the end frames that are missing
    assert np.abs(frame_shift) <= 5, f"Photometry to DAQ frame shift is too large: {frame_shift}"
    assert np.abs(drift_ppm) < max_drift_ppm, f"Photometry to DAQ drift is too large: {drift_ppm} ppm"
    ts_daq = fcn_fp2daq(df_photometry['Timestamp'].values)  # those are the timestamps in daq time
    return ts_daq, fcn_fp2daq, drift_ppm

116 

117 

def read_daq_voltage(daq_file, chmap=DAQ_CHMAP):
    """
    Load the analog DAQ channels listed in the channel map and remove their median.

    :param daq_file: path to the DAQ tdms file.
    :param chmap: dictionary mapping signal names (e.g. 'photometry', 'bpod') to channel names
        of the tdms 'Analog' group.
    :return: tuple (vdaq, fs): dictionary of median-subtracted voltage traces as returned by
        ``load_channels_tdms`` (presumably keyed as in ``chmap``), and the DAQ sampling frequency.
    :raises AssertionError: if any channel of ``chmap`` is absent from the 'Analog' group.
    """
    channel_names = {c.name for c in load_raw_daq_tdms(daq_file)['Analog'].channels()}
    missing = [v for v in chmap.values() if v not in channel_names]
    assert not missing, f"Missing channel: {missing}"
    vdaq, fs = load_channels_tdms(daq_file, chmap=chmap)
    # subtract the median of each trace so that downstream thresholds are relative to its baseline
    vdaq = {k: v - np.median(v) for k, v in vdaq.items()}
    return vdaq, fs

124 

125 

def read_daq_timestamps(vdaq, v_threshold=V_THRESHOLD):
    """
    From the DAQ voltage traces, extracts the photometry frames and their tagging by bpod.

    :param vdaq: dictionary of the voltage traces from the DAQ. Each item has a key describing
        the channel as per the channel map, and contains a single voltage trace. The keys
        'photometry' and 'bpod' are required.
    :param v_threshold: voltage step used to detect the photometry frame pulses, also the level
        above which the bpod trace is considered HIGH.
    :return: tuple (daq_frames, tagged_frames): the DAQ sample indices of the photometry frame
        pulses, and a boolean array that is True where the bpod trace is above ``v_threshold``
        at those samples.
    """
    daq_frames = rises(vdaq['photometry'], step=v_threshold, analog=True)
    if daq_frames.size == 0:
        # no pulse found: the photometry channel may be wired with reversed polarity
        daq_frames = rises(-vdaq['photometry'], step=v_threshold, analog=True)
        _logger.warning(f'No photometry pulses detected, attempting to reverse voltage and detect again, '
                        f'found {daq_frames.size} in reverse voltage. CHECK YOUR FP WIRING TO THE DAQ !!')
    tagged_frames = vdaq['bpod'][daq_frames] > v_threshold
    return daq_frames, tagged_frames

141 

142 

def check_timestamps(daq_file, photometry_file, tolerance=20, chmap=DAQ_CHMAP, v_threshold=V_THRESHOLD):
    """
    Reads data files and checks that the number of timestamps check out with a tolerance of n_frames.

    :param daq_file: path to the DAQ tdms file (str or pathlib.Path).
    :param photometry_file: path to the photometry csv file (str or pathlib.Path).
    :param tolerance: number of acceptable missing frames between the daq and the photometry file;
        the absolute difference of frame counts must be strictly lower than this value.
    :param chmap: channel map passed to ``read_daq_voltage``.
    :param v_threshold: voltage threshold passed to ``read_daq_timestamps``.
    :return: None
    :raises AssertionError: if the frame counts differ by ``tolerance`` frames or more, in
        either direction.
    """
    from pathlib import Path
    daq_file, photometry_file = Path(daq_file), Path(photometry_file)
    df_photometry = pd.read_csv(photometry_file)
    vdaq, _ = read_daq_voltage(daq_file=daq_file, chmap=chmap)
    daq_frames, _ = read_daq_timestamps(vdaq=vdaq, v_threshold=v_threshold)
    n_daq, n_photometry = daq_frames.shape[0], df_photometry.shape[0]
    n_diff = n_daq - n_photometry
    # log before checking so that the frame counts are also reported when the check fails
    _logger.info(f"{n_diff} frames difference, "
                 f"{'/'.join(daq_file.parts[-2:])}: {n_daq} frames, "
                 f"{'/'.join(photometry_file.parts[-2:])}: {n_photometry}")
    # the check is symmetric: a DAQ with far fewer frames than the photometry must also fail
    assert abs(n_diff) < tolerance, f"{n_diff} frames difference exceeds the tolerance of {tolerance}"

160 

161 

class FibrePhotometry(BaseExtractor):
    """
    Extractor for Neurophotometrics fibrephotometry data.

    Usage: FibrePhotometry(self.session_path, collection=self.collection)

    The extraction output is a single table with one column per region plus the
    'times', 'wavelength', 'name' and 'color' columns.
    """
    # NB: these are plain strings, not tuples (the parentheses in earlier versions did not make
    # tuples); _extract returns a single DataFrame.
    save_names = 'photometry.signal.pqt'
    var_names = 'df_out'

    def __init__(self, *args, collection='raw_photometry_data', **kwargs):
        """An extractor for all Neurophotometrics fibrephotometry data.

        Parameters
        ----------
        collection : str
            Relative path from the session root folder to the raw photometry data.
        *args, **kwargs
            Passed to the BaseExtractor constructor (e.g. the session path).
        """
        self.collection = collection
        super().__init__(*args, **kwargs)

    @staticmethod
    def _channel_meta(light_source_map=None):
        """
        Return table of light source wavelengths and corresponding colour labels.

        Parameters
        ----------
        light_source_map : dict
            An optional map of light source wavelengths (nm) used and their corresponding colour name.
            Defaults to LIGHT_SOURCE_MAP when None or empty.

        Returns
        -------
        pandas.DataFrame
            A table with one row per light source, indexed by 'channel_id', with one column per
            key of the light source map.
        """
        light_source_map = light_source_map or LIGHT_SOURCE_MAP
        meta = pd.DataFrame.from_dict(light_source_map)
        meta.index.rename('channel_id', inplace=True)
        return meta

    def _extract(self, light_source_map=None, collection=None, regions=None, **kwargs):
        """
        Extract the signal of each region, with its timestamps and light source information.

        Parameters
        ----------
        light_source_map : dict
            An optional map of light source wavelengths (nm) used and their corresponding colour name.
            Defaults to LIGHT_SOURCE_MAP.
        collection : str / pathlib.Path
            An optional relative path from the session root folder to find the raw photometry data.
            Defaults to the extractor's collection (`raw_photometry_data` unless set otherwise).
        regions : list of str
            The list of regions to extract. If None extracts all columns containing "Region". Defaults to None.
        **kwargs
            Passed on to `extract_timestamps`.

        Returns
        -------
        pandas.DataFrame
            A table of intensity for each region, with associated times, wavelengths, names and colors.
            Frames whose LED state is not found in the LED states table keep a NaN wavelength and
            empty name and color.

        Raises
        ------
        KeyError
            If the raw photometry data has neither a 'LedState' nor a 'Flags' column.
        """
        collection = collection or self.collection
        fp_data = alfio.load_object(self.session_path / collection, 'fpData')
        raw = fp_data['raw']
        # the DAQ file is looked up in the same collection as the photometry data
        ts = self.extract_timestamps(raw, collection=collection, **kwargs)

        # Load channels metadata and the table of LED state flags
        channel_meta_map = self._channel_meta(light_source_map)
        led_states = fp_data.get('channels', pd.DataFrame(NEUROPHOTOMETRICS_LED_STATES))
        led_states = led_states.set_index('Condition')
        # Extract signal columns
        regions = regions or [k for k in raw.keys() if 'Region' in k]
        out_df = raw.filter(items=regions, axis=1).sort_index(axis=1)
        out_df['times'] = ts
        out_df['wavelength'] = np.nan
        out_df['name'] = ''
        out_df['color'] = ''
        # Extract channel index: the LED state column containing a frame's flag value selects the
        # corresponding row of the light source table
        states = raw.get('LedState', raw.get('Flags', None))
        if states is None:
            raise KeyError("Raw photometry data has neither a 'LedState' nor a 'Flags' column")
        for state in states.unique():
            _, ic = np.where(led_states == state)
            if ic.size == 0:
                continue
            for cn in ['name', 'color', 'wavelength']:
                out_df.loc[states == state, cn] = channel_meta_map.iloc[ic[0]][cn]
        return out_df

    def extract_timestamps(self, fp_data, *, collection=None, task_collection='raw_behavior_data', **kwargs):
        """Extract the photometry.timestamps array.

        This depends on the DAQ and task synchronization protocol: the photometry frames are first
        synchronised to the DAQ, then the DAQ clock is synchronised to bpod by matching the rises of
        the DAQ bpod channel with the bpod go cue trigger times.

        Parameters
        ----------
        fp_data : pandas.DataFrame
            The raw fibrephotometry table (the 'raw' attribute of the fpData object), with at least
            the 'Timestamp', 'Input0' and 'FrameCounter' columns.
        collection : str / pathlib.Path
            Relative path from the session root folder to the folder containing the DAQ tdms file.
            Defaults to the extractor's collection.
        task_collection : str
            Relative path from the session root folder to the bpod task data.
            Defaults to 'raw_behavior_data'.
        **kwargs
            Unused.

        Returns
        -------
        numpy.ndarray
            An array of timestamps in bpod time, one per frame.

        Raises
        ------
        FileNotFoundError
            If no tdms file is found in the collection.
        AssertionError
            If one of the synchronisation checks fails.
        """
        daq_folder = self.session_path.joinpath(collection or self.collection)
        daq_file = next(daq_folder.glob('*.tdms'), None)
        if daq_file is None:
            raise FileNotFoundError(f"No DAQ tdms file found in {daq_folder}")
        vdaq, fs = read_daq_voltage(daq_file, chmap=DAQ_CHMAP)
        # photometry frames -> DAQ time
        ts, _, _ = sync_photometry_to_daq(vdaq=vdaq, fs=fs, df_photometry=fp_data, v_threshold=V_THRESHOLD)
        # DAQ time -> bpod time
        gc_bpod, _ = GoCueTriggerTimes(session_path=self.session_path).extract(
            task_collection=task_collection, save=False)
        gc_daq = rises(vdaq['bpod'])
        fcn_daq2_bpod, drift_ppm, idaq, _ = sync_timestamps(gc_daq / fs, gc_bpod, return_indices=True)
        assert np.abs(drift_ppm) < 100, f"Drift between bpod and daq is above 100 ppm: {drift_ppm}"
        assert (gc_daq.size - idaq.size) < 5, "Bpod and daq synchronisation failed as too few " \
                                              "events could be matched"
        return fcn_daq2_bpod(ts)