Coverage for ibllib/io/extractors/biased_trials.py: 99%

98 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1from pathlib import Path, PureWindowsPath 

2 

3from packaging import version 

4import numpy as np 

5from one.alf.io import AlfBunch 

6 

7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

8import ibllib.io.raw_data_loaders as raw 

9from ibllib.io.extractors.training_trials import ( 

10 Choice, FeedbackTimes, FeedbackType, GoCueTimes, GoCueTriggerTimes, 

11 IncludedTrials, Intervals, ProbabilityLeft, ResponseTimes, RewardVolume, 

12 StimOnTriggerTimes, StimOnOffFreezeTimes, ItiInTimes, 

13 StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, PhasePosQuiescence) 

14from ibllib.io.extractors.training_wheel import Wheel 

15 

# Only the two composite trial extractors are exported by `from ... import *`; the helper
# extractor classes defined in this module are not listed here.
__all__ = ['BiasedTrials', 'EphysTrials']

17 

18 

class ContrastLR(BaseBpodTrialsExtractor):
    """Get left and right contrasts from raw datafile.

    The stimulus side of each Bpod trial is read from the sign of its 'position' value:
    a negative position yields a left contrast and a positive position a right contrast.
    The contrast of the opposite side is NaN; a zero position gives NaN on both sides.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self, **kwargs):
        def side_contrast(is_on_side):
            # Keep the trial contrast where the stimulus sign matches the requested side, else NaN
            return np.array([
                trial['contrast'] if is_on_side(np.sign(trial['position'])) else np.nan
                for trial in self.bpod_trials
            ])

        left = side_contrast(lambda sign: sign < 0)
        right = side_contrast(lambda sign: sign > 0)
        return left, right

30 

31 

class ProbaContrasts(BaseBpodTrialsExtractor):
    """Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence."""
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
                  '_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy')
    var_names = ('contrastLeft', 'contrastRight', 'phase',
                 'position', 'probabilityLeft', 'quiescence')

    def _extract(self, **kwargs):
        """Extracts positions, contrasts, quiescent delay, stimulus phase and probability left
        from pregenerated session files. Used in ephysChoiceWorld extractions.
        Optional: saves alf contrastLR and probabilityLeft npy files

        Returns
        -------
        list of numpy.ndarray
            One array per entry of `var_names`, in the same order.
        """
        pe = self.get_pregenerated_events(self.bpod_trials, self.settings)
        # Order the outputs explicitly by var_names instead of relying on the alphabetical
        # order of the dict keys happening to coincide with it.
        return [pe[k] for k in self.var_names]

    @staticmethod
    def get_pregenerated_events(bpod_trials, settings):
        """Load the pre-generated stimulus sequence of a session and truncate it to the trials run.

        Parameters
        ----------
        bpod_trials : sized
            Raw Bpod trials data; only its length is used, to truncate the pre-generated arrays.
        settings : dict
            Session settings. The pre-generated session number is the first non-None value of
            'PRELOADED_SESSION_NUM', 'PREGENERATED_SESSION_NUM' or 'SESSION_TEMPLATE_ID'; failing
            that, the digits of the file name in 'SESSION_LOADED_FILE_PATH'. 'IBLRIG_VERSION'
            decides whether a separate stimulus phase file is used.

        Returns
        -------
        dict
            Keys 'position', 'quiescence', 'phase', 'probabilityLeft', 'contrastRight' and
            'contrastLeft', each an array truncated to ``len(bpod_trials)``.

        Raises
        ------
        ValueError
            If no session number can be determined from the settings.
        """
        num = None
        for key in ('PRELOADED_SESSION_NUM', 'PREGENERATED_SESSION_NUM', 'SESSION_TEMPLATE_ID'):
            num = settings.get(key, None)
            if num is not None:
                break
        if num is None:
            # `or ''` also covers a key that is present but null, which would make
            # PureWindowsPath raise a TypeError instead of the ValueError below.
            # PureWindowsPath splits backslash-separated paths correctly on any OS.
            fn = PureWindowsPath(settings.get('SESSION_LOADED_FILE_PATH') or '').name
            num = ''.join(d for d in fn if d.isdigit())
            if num == '':
                raise ValueError("Can't extract left probability behaviour.")

        # Load the pregenerated file
        ntrials = len(bpod_trials)
        sessions_folder = Path(raw.__file__).parent.joinpath("extractors", "ephys_sessions")
        pcqsp = np.load(sessions_folder.joinpath(f"session_{num}_ephys_pcqs.npy"))
        # Columns used: 0 position, 1 contrast, 2 quiescence, 3 phase, 4 probabilityLeft
        pos = pcqsp[:ntrials, 0]
        con = pcqsp[:ntrials, 1]
        qui = pcqsp[:ntrials, 2]
        phase = pcqsp[:ntrials, 3]
        pLeft = pcqsp[:ntrials, 4]

        # Positive positions carry a right contrast, negative ones a left contrast
        contrastRight = con.copy()
        contrastLeft = con.copy()
        contrastRight[pos < 0] = np.nan
        contrastLeft[pos > 0] = np.nan

        # For rig versions above 6.4.0 the phase comes from a separate stim_phase file, when present
        phase_path = sessions_folder.joinpath(f"session_{num}_stim_phase.npy")
        is_patched_version = version.parse(
            settings.get('IBLRIG_VERSION') or '0') > version.parse('6.4.0')
        if phase_path.exists() and is_patched_version:
            phase = np.load(phase_path)[:ntrials]

        return {'position': pos, 'quiescence': qui, 'phase': phase, 'probabilityLeft': pLeft,
                'contrastRight': contrastRight, 'contrastLeft': contrastLeft}

87 

88 

class TrialsTableBiased(BaseBpodTrialsExtractor):
    """
    Extract the biased-task trials table and wheel data from Bpod raw data.

    Trials table columns (every extractor output not named in `var_names`):
    intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
    feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times

    Returned alongside the table: stimOff_times, stimFreeze_times and the wheel data
    wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude,
    wheelMoves_peakVelocity_times, is_final_movement
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times',
                 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the component extractors and split their outputs into the table and separate values.

        :param extractor_classes: optional extra extractor classes appended to the default list;
            outputs not named in `var_names` go into the table (whose size is asserted to be 12).
        :return: tuple of the table (via AlfBunch.to_df) followed by one value per non-'table'
            entry of `var_names`, in that order.
        """
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
                      FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        extractors = extractors + (extractor_classes or [])
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Move the table columns out of `results`, leaving only the separately returned values
        column_names = [name for name in list(results.keys()) if name not in self.var_names]
        columns = {}
        for name in column_names:
            columns[name] = results.pop(name)
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        remaining = (results.pop(name) for name in self.var_names if name != 'table')
        return (table.to_df(), *remaining)

114 

115 

class TrialsTableEphys(BaseBpodTrialsExtractor):
    """
    Extract the ephys-task trials table and wheel data from Bpod raw data.

    Trials table columns (every extractor output not named in `var_names`):
    intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
    feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times

    Unlike TrialsTableBiased, contrasts and probabilityLeft come from the pre-generated session
    sequence (ProbaContrasts) rather than from the Bpod trial data.

    Returned alongside the table: stimOff_times, stimFreeze_times, the wheel data
    (wheel_timestamps, wheel_position, wheelMoves_intervals, wheelMoves_peakAmplitude,
    wheelMoves_peakVelocity_times, is_final_movement) and the pre-generated phase, position
    and quiescence.
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  None, None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times',
                 'is_final_movement', 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        """Run the component extractors and split their outputs into the table and separate values.

        :param extractor_classes: optional extra extractor classes appended to the default list;
            outputs not named in `var_names` go into the table (whose size is asserted to be 12).
        :return: tuple of the table (via AlfBunch.to_df) followed by one value per non-'table'
            entry of `var_names`, in that order.
        """
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
                      FeedbackTimes, FeedbackType, RewardVolume, Wheel]
        extractors = extractors + (extractor_classes or [])
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Everything that is not returned separately becomes a table column
        columns = {}
        for name, values in results.items():
            if name not in self.var_names:
                columns[name] = values
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        remaining = (results.pop(name) for name in self.var_names if name != 'table')
        return (table.to_df(), *remaining)

143 

144 

class BiasedTrials(BaseBpodTrialsExtractor):
    """
    Same as training_trials.TrainingTrials except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater
       (NOTE(review): no version check is visible in this class; presumably it lives in
       IncludedTrials - confirm)
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None, '_ibl_trials.table.pqt',
                  '_ibl_trials.stimOff_times.npy', None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times',
                 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
                 'included', 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """Run the trigger-time, trials-table, included-trials and phase/position/quiescence extractors.

        :param extractor_classes: optional extra extractor classes appended to the default list.
        :return: dict mapping each name in `var_names` to its extracted value; other outputs of
            the extractors are discarded.
        """
        extractors = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials,
                      PhasePosQuiescence] + (extractor_classes or [])
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        selected = {}
        for name in self.var_names:
            selected[name] = results[name]
        return selected

171 

172 

class EphysTrials(BaseBpodTrialsExtractor):
    """
    Same as BiasedTrials except...
     - Contrast, phase, position, probabilityLeft and quiescence are extracted differently:
       for iblrig < 8 they come from the pre-generated session sequence (TrialsTableEphys);
       for iblrig >= 8 the biased trials table (TrialsTableBiased) is used instead.
     - The Bpod frame2ttl and audio TTL fronts are loaded and stored on the instance for QC.
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None,
                  '_ibl_trials.stimOffTrigger_times.npy', None, None,
                  '_ibl_trials.table.pqt', '_ibl_trials.stimOff_times.npy', None, '_ibl_wheel.timestamps.npy',
                  '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy',
                  None, None, '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'table', 'stimOff_times',
                 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals',
                 'wheelMoves_peakAmplitude', 'wheelMoves_peakVelocity_times', 'is_final_movement',
                 'included', 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs) -> dict:
        """Run the ephys trial extractors, choosing the trials table from the iblrig version.

        Side effect: sets ``self.frame2ttl`` and ``self.audio`` from ``raw.load_bpod_fronts``.

        :param extractor_classes: optional extra extractor classes appended to the default list.
        :return: dict mapping each name in `var_names` to its extracted value.
        """
        extractor_classes = extractor_classes or []

        # For iblrig v8 we use the biased trials table instead. ContrastLeft, ContrastRight and ProbabilityLeft are
        # filled from the values in the bpod data itself rather than using the pregenerated session number.
        # `or` chaining (rather than a .get default) also falls back when a key is present but null or empty,
        # which would otherwise make version.parse raise.
        iblrig_version = (self.settings.get('IBLRIG_VERSION')
                          or self.settings.get('IBLRIG_VERSION_TAG')
                          or '0')
        if version.parse(iblrig_version) >= version.parse('8.0.0'):
            TrialsTable = TrialsTableBiased
        else:
            TrialsTable = TrialsTableEphys

        # NOTE(review): for iblrig < 8, TrialsTableEphys also returns 'phase', 'position' and 'quiescence',
        # as presumably does PhasePosQuiescence; which value is kept depends on how run_extractor_classes
        # merges outputs - confirm.
        base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes,
                ErrorCueTriggerTimes, TrialsTable, IncludedTrials, PhasePosQuiescence]
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        out, _ = run_extractor_classes(
            base + extractor_classes, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        return {k: out[k] for k in self.var_names}