Coverage for ibllib/io/extractors/biased_trials.py: 97%

114 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1from pathlib import Path, PureWindowsPath 

2 

3from packaging import version 

4import numpy as np 

5from one.alf.io import AlfBunch 

6 

7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

8import ibllib.io.raw_data_loaders as raw 

9from ibllib.io.extractors.training_trials import ( 

10 Choice, FeedbackTimes, FeedbackType, GoCueTimes, GoCueTriggerTimes, 

11 IncludedTrials, Intervals, ProbabilityLeft, ResponseTimes, RewardVolume, 

12 StimOnTimes_deprecated, StimOnTriggerTimes, StimOnOffFreezeTimes, ItiInTimes, 

13 StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, PhasePosQuiescence) 

14from ibllib.io.extractors.training_wheel import Wheel 

15 

# Names exported by `from ... import *`: the module-level entry point and the two composite
# trials extractors. The helper extractors defined below are intentionally not listed.
__all__ = ['extract_all', 'BiasedTrials', 'EphysTrials']

17 

18 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Split each Bpod trial's contrast into separate left and right contrast arrays.

    A trial whose stimulus position is negative contributes its contrast to the left array and
    NaN to the right; a positive position does the opposite. A position of exactly zero yields
    NaN on both sides.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self, **kwargs):
        left, right = [], []
        for trial in self.bpod_trials:
            side = np.sign(trial['position'])
            # The contrast is only read for the side the stimulus was actually shown on
            left.append(trial['contrast'] if side < 0 else np.nan)
            right.append(trial['contrast'] if side > 0 else np.nan)
        return np.array(left), np.array(right)

30 

31 

class ProbaContrasts(BaseBpodTrialsExtractor):
    """Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence."""
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
                  '_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy')
    var_names = ('contrastLeft', 'contrastRight', 'phase',
                 'position', 'probabilityLeft', 'quiescence')

    def _extract(self, **kwargs):
        """
        Extract positions, contrasts, quiescent delay, stimulus phase and probability left
        from pregenerated session files. Used in ephysChoiceWorld extractions.

        :return: list of arrays, one per name in ``self.var_names`` and in that order
        """
        pe = self.get_pregenerated_events(self.bpod_trials, self.settings)
        # Order the outputs explicitly by var_names; relying on sorted(pe.keys()) only worked
        # because the alphabetical key order happened to coincide with var_names.
        return [pe[k] for k in self.var_names]

    @staticmethod
    def get_pregenerated_events(bpod_trials, settings):
        """
        Load the pregenerated ephys session arrays that match this session.

        The pregenerated session number is the first non-None value among the settings keys
        PRELOADED_SESSION_NUM, PREGENERATED_SESSION_NUM and SESSION_TEMPLATE_ID. Failing that,
        it is built from the digits in the file name of SESSION_LOADED_FILE_PATH (parsed as a
        Windows path).

        :param bpod_trials: raw Bpod trials; only its length is used, to truncate the arrays
        :param settings: session settings dict
        :return: dict with keys 'position', 'quiescence', 'phase', 'probabilityLeft',
            'contrastRight' and 'contrastLeft'
        :raises ValueError: if no session number can be determined from the settings
        """
        num = None
        for key in ('PRELOADED_SESSION_NUM', 'PREGENERATED_SESSION_NUM', 'SESSION_TEMPLATE_ID'):
            num = settings.get(key, None)
            if num is not None:
                break
        if num is None:
            fn = PureWindowsPath(settings.get('SESSION_LOADED_FILE_PATH', '')).name
            num = ''.join(d for d in fn if d.isdigit())
            if num == '':
                raise ValueError("Can't extract left probability behaviour.")

        # Load the pregenerated file and keep only as many rows as there are Bpod trials
        ntrials = len(bpod_trials)
        sessions_folder = Path(raw.__file__).parent.joinpath("extractors", "ephys_sessions")
        pcqsp = np.load(sessions_folder.joinpath(f"session_{num}_ephys_pcqs.npy"))[:ntrials]
        # Column layout of the pcqs file: position, contrast, quiescence, phase, probabilityLeft
        pos = pcqsp[:, 0]
        con = pcqsp[:, 1]
        qui = pcqsp[:, 2]
        phase = pcqsp[:, 3]
        pLeft = pcqsp[:, 4]

        # The contrast belongs to the side the stimulus was shown on; the other side is NaN
        contrastRight = con.copy()
        contrastLeft = con.copy()
        contrastRight[pos < 0] = np.nan
        contrastLeft[pos > 0] = np.nan

        # A separately stored stimulus phase, when present, replaces the pcqs phase column,
        # but only for rig versions newer than 6.4.0
        phase_path = sessions_folder.joinpath(f"session_{num}_stim_phase.npy")
        is_patched_version = version.parse(
            settings.get('IBLRIG_VERSION') or '0') > version.parse('6.4.0')
        if phase_path.exists() and is_patched_version:
            phase = np.load(phase_path)[:ntrials]

        return {'position': pos, 'quiescence': qui, 'phase': phase, 'probabilityLeft': pLeft,
                'contrastRight': contrastRight, 'contrastLeft': contrastLeft}

87 

88 

class TrialsTableBiased(BaseBpodTrialsExtractor):
    """
    Build the biased-task trials table and wheel data from Bpod raw data.

    Trials-table columns: intervals, goCue_times, response_times, choice, stimOn_times,
    contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft,
    firstMovement_times.
    Returned alongside the table: stimOff_times, stimFreeze_times, wheel_timestamps,
    wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude, peakVelocity_times,
    is_final_movement.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        classes = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR, FeedbackTimes,
                   FeedbackType, RewardVolume, ProbabilityLeft, Wheel] + (extractor_classes or [])
        out, _ = run_extractor_classes(
            classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Any output that is not one of this extractor's own variables becomes a table column
        columns = [name for name in out if name not in self.var_names]
        table = AlfBunch({name: out.pop(name) for name in columns})
        assert len(table.keys()) == 12

        remaining = (out.pop(name) for name in self.var_names if name != 'table')
        return (table.to_df(), *remaining)

114 

115 

class TrialsTableEphys(BaseBpodTrialsExtractor):
    """
    Build the ephys-task trials table and wheel data from Bpod raw data.

    Trials-table columns: intervals, goCue_times, response_times, choice, stimOn_times,
    contrastLeft, contrastRight, feedback_times, feedbackType, rewardVolume, probabilityLeft,
    firstMovement_times. Contrasts and probabilityLeft come from ProbaContrasts, i.e. from the
    pregenerated session files.
    Returned alongside the table: stimOff_times, stimFreeze_times, the wheel arrays,
    peakVelocity_times, is_final_movement, phase, position and quiescence.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  None, None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        classes = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
                   FeedbackTimes, FeedbackType, RewardVolume, Wheel] + (extractor_classes or [])
        out, _ = run_extractor_classes(
            classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # This extractor's own variables are returned separately, everything else is a column
        table = AlfBunch({name: out[name] for name in out if name not in self.var_names})
        assert len(table.keys()) == 12

        remaining = (out[name] for name in self.var_names if name != 'table')
        return (table.to_df(), *remaining)

143 

144 

class BiasedTrials(BaseBpodTrialsExtractor):
    """
    Composite trials extractor for the biased task.

    Differs from training_trials.TrainingTrials in that:
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
                  None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """
        Run the trigger-time, trials-table, included-trials and phase/position/quiescence
        extractors (plus any extra classes) without saving, and collect their outputs.

        :param extractor_classes: optional list of additional extractor classes to run
        :return: dict mapping each name in ``var_names`` to its extracted value
        """
        classes = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
                   ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials,
                   PhasePosQuiescence] + (extractor_classes or [])
        results, _ = run_extractor_classes(
            classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        return {name: results[name] for name in self.var_names}

170 

171 

class EphysTrials(BaseBpodTrialsExtractor):
    """
    Composite trials extractor for the ephys task.

    Same as BiasedTrials except that, for rig versions below 8.0.0, contrast, phase, position,
    probabilityLeft and quiescence are extracted differently (from the pregenerated session
    files via TrialsTableEphys).
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """
        Run all ephys trial extractors (plus any extra classes) without saving.

        Also stores the Bpod frame2ttl and audio fronts on ``self.frame2ttl`` / ``self.audio``
        for QC purposes.

        :param extractor_classes: optional list of additional extractor classes to run
        :return: dict mapping each name in ``var_names`` to its extracted value
        """
        extractor_classes = extractor_classes or []

        # For iblrig v8 we use the biased trials table instead. ContrastLeft, ContrastRight and ProbabilityLeft are
        # filled from the values in the bpod data itself rather than using the pregenerated session number.
        # A missing, null or empty IBLRIG_VERSION falls back to IBLRIG_VERSION_TAG and then to '0'
        # (a pre-v8 rig); passing '' or None to version.parse would raise instead.
        iblrig_version = (self.settings.get('IBLRIG_VERSION')
                          or self.settings.get('IBLRIG_VERSION_TAG')
                          or '0')
        if version.parse(iblrig_version) >= version.parse('8.0.0'):
            TrialsTable = TrialsTableBiased
        else:
            TrialsTable = TrialsTableEphys

        base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
                ErrorCueTriggerTimes, TrialsTable, IncludedTrials, PhasePosQuiescence]
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        out, _ = run_extractor_classes(
            base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        return {k: out[k] for k in self.var_names}

206 

207 

def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
                task_collection='raw_behavior_data', save_path=None):
    """
    Extract all trials data for a biased-task session.

    Same as training_trials.extract_all except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    :param session_path: path to the session folder
    :param save: forwarded to run_extractor_classes to control whether outputs are saved
    :param bpod_trials: raw Bpod trials data; loaded from ``session_path`` when falsy
    :param settings: session settings dict; loaded from ``session_path`` when falsy. A missing,
        null or empty 'IBLRIG_VERSION' is treated as '100.0.0' (newest rig). The caller's dict
        is never modified; a patched copy is used instead.
    :param extra_classes: additional BaseBpodTrialsExtractor subclasses for custom extractions
    :param task_collection: collection (sub-folder) containing the raw task data
    :param save_path: forwarded to run_extractor_classes as ``path_out``
    :return: the ``(out, fil)`` pair returned by run_extractor_classes
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION': '100.0.0'}

    # Without a usable rig version, assume the newest one. Previously a missing key raised
    # KeyError and None crashed version.parse; patch a copy so the caller's dict is untouched.
    if not settings.get('IBLRIG_VERSION'):
        settings = {**settings, 'IBLRIG_VERSION': '100.0.0'}

    # Version check
    if version.parse(settings['IBLRIG_VERSION']) >= version.parse('5.0.0'):
        # We now extract a single trials table
        base = [BiasedTrials]
    else:
        base = [
            GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    if extra_classes:
        base.extend(extra_classes)

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil