Coverage for ibllib/io/extractors/biased_trials.py: 94%

104 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1from pathlib import Path, PureWindowsPath 

2 

3from pkg_resources import parse_version 

4import numpy as np 

5from one.alf.io import AlfBunch 

6 

7from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

8import ibllib.io.raw_data_loaders as raw 

9from ibllib.io.extractors.training_trials import ( 

10 Choice, FeedbackTimes, FeedbackType, GoCueTimes, GoCueTriggerTimes, 

11 IncludedTrials, Intervals, ProbabilityLeft, ResponseTimes, RewardVolume, 

12 StimOnTimes_deprecated, StimOnTriggerTimes, StimOnOffFreezeTimes, ItiInTimes, 

13 StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, PhasePosQuiescence) 

14from ibllib.io.extractors.training_wheel import Wheel 

15 

# Public names exported by `from ... import *`: the module-level extraction entry point and the
# two top-level trial extractors defined below.
__all__ = ['extract_all', 'BiasedTrials', 'EphysTrials']

17 

18 

class ContrastLR(BaseBpodTrialsExtractor):
    """
    Split each trial's stimulus contrast into left and right arrays from the raw Bpod data.

    The side is given by the sign of the trial's 'position':
    - a negative position puts the contrast in `contrastLeft`;
    - a positive position puts it in `contrastRight`.
    The other side is NaN. When the position is zero, both sides are NaN.
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy')
    var_names = ('contrastLeft', 'contrastRight')

    def _extract(self, **kwargs):
        left, right = [], []
        for trial in self.bpod_trials:
            side = np.sign(trial['position'])
            left.append(trial['contrast'] if side < 0 else np.nan)
            right.append(trial['contrast'] if side > 0 else np.nan)
        return np.array(left), np.array(right)

32 

33 

class ProbaContrasts(BaseBpodTrialsExtractor):
    """
    Bpod pre-generated values for probabilityLeft, contrastLR, phase, quiescence
    """
    save_names = ('_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', None, None,
                  '_ibl_trials.probabilityLeft.npy', '_ibl_trials.quiescencePeriod.npy')
    var_names = ('contrastLeft', 'contrastRight', 'phase',
                 'position', 'probabilityLeft', 'quiescence')

    def _extract(self, **kwargs):
        """Extracts positions, contrasts, quiescent delay, stimulus phase and probability left
        from pregenerated session files. Used in ephysChoiceWorld extractions.
        Optional: saves alf contrastLR and probabilityLeft npy files

        :return: list of arrays, one per entry of `var_names` and in the same order
        """
        pe = self.get_pregenerated_events(self.bpod_trials, self.settings)
        # Index by var_names so the output order is tied to the declared names (and therefore
        # to save_names), not to the alphabetical order of the dictionary keys.
        return [pe[k] for k in self.var_names]

    @staticmethod
    def get_pregenerated_events(bpod_trials, settings):
        """
        Load the pre-generated session parameters and truncate them to the number of trials.

        :param bpod_trials: raw Bpod trial data; only its length is used
        :param settings: session settings dict. The pre-generated session number is read from
            the first of these that is available:
            - 'PRELOADED_SESSION_NUM';
            - 'PREGENERATED_SESSION_NUM';
            - the digits in the file name of 'SESSION_LOADED_FILE_PATH'.
            'IBLRIG_VERSION_TAG' decides whether the separate stimulus phase file is used; a
            missing or empty tag is treated as '0.0.0'.
        :return: dict with keys 'position', 'quiescence', 'phase', 'probabilityLeft',
            'contrastRight' and 'contrastLeft'. Each value is an array truncated to at most
            len(bpod_trials) elements.
        :raises ValueError: if no session number can be determined from the settings
        """
        num = settings.get("PRELOADED_SESSION_NUM", None)
        if num is None:
            num = settings.get("PREGENERATED_SESSION_NUM", None)
        if num is None:
            # PureWindowsPath accepts both '\\' and '/' separators, so the file name is
            # recovered whichever separator the stored path uses.
            fn = PureWindowsPath(settings.get('SESSION_LOADED_FILE_PATH', '')).name
            num = ''.join(d for d in fn if d.isdigit())
            if num == '':
                raise ValueError("Can't extract left probability behaviour.")

        # Load the pregenerated file
        ntrials = len(bpod_trials)
        sessions_folder = Path(raw.__file__).parent.joinpath("extractors", "ephys_sessions")
        # Columns of the pre-generated array are position, contrast, quiescence, phase and
        # probabilityLeft. Rows beyond the number of Bpod trials are dropped.
        pcqsp = np.load(sessions_folder.joinpath(f"session_{num}_ephys_pcqs.npy"))[:ntrials]
        pos = pcqsp[:, 0]
        con = pcqsp[:, 1]
        qui = pcqsp[:, 2]
        phase = pcqsp[:, 3]
        pLeft = pcqsp[:, 4]

        # The contrast belongs to the side given by the sign of the position; the other side is NaN
        contrastRight = con.copy()
        contrastLeft = con.copy()
        contrastRight[pos < 0] = np.nan
        contrastLeft[pos > 0] = np.nan

        # Rig versions after 6.4.0 use the separate stimulus phase file when it exists.
        # parse_version only accepts strings, so a missing/None/empty tag falls back to '0.0.0'
        # (unpatched) instead of raising.
        version_tag = settings.get('IBLRIG_VERSION_TAG') or '0.0.0'
        is_patched_version = parse_version(version_tag) > parse_version('6.4.0')
        phase_path = sessions_folder.joinpath(f"session_{num}_stim_phase.npy")
        if phase_path.exists() and is_patched_version:
            phase = np.load(phase_path)[:ntrials]

        return {'position': pos, 'quiescence': qui, 'phase': phase, 'probabilityLeft': pLeft,
                'contrastRight': contrastRight, 'contrastLeft': contrastLeft}

90 

91 

class TrialsTableBiased(BaseBpodTrialsExtractor):
    """
    Build the trials table for biased sessions from the raw Bpod data.

    The table holds every base-extractor output that is not listed in `var_names`:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    The remaining outputs are returned next to the table:
        stimOff_times, stimFreeze_times, wheel_timestamps, wheel_position, wheel_moves_intervals,
        wheel_moves_peak_amplitude, peakVelocity_times, is_final_movement
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None)
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ContrastLR,
                      FeedbackTimes, FeedbackType, RewardVolume, ProbabilityLeft, Wheel]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Move every output that is not returned separately into the trials table
        table_keys = [key for key in results if key not in self.var_names]
        table = AlfBunch({key: results.pop(key) for key in table_keys})
        assert len(table.keys()) == 12

        df = table.to_df()
        return (df, *(results.pop(name) for name in self.var_names if name != 'table'))

115 

116 

class TrialsTableEphys(BaseBpodTrialsExtractor):
    """
    Build the trials table for ephys sessions from the raw Bpod data.

    The table holds every base-extractor output that is not listed in `var_names`:
        intervals, goCue_times, response_times, choice, stimOn_times, contrastLeft, contrastRight,
        feedback_times, feedbackType, rewardVolume, probabilityLeft, firstMovement_times
    Contrasts and probabilityLeft come from ProbaContrasts (the pre-generated session files)
    rather than from ContrastLR / ProbabilityLeft.
    The remaining outputs are returned next to the table:
        stimOff_times, stimFreeze_times, wheel_timestamps, wheel_position, wheel_moves_intervals,
        wheel_moves_peak_amplitude, peakVelocity_times, is_final_movement, phase, position,
        quiescence
    """
    save_names = ('_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None,
                  None, None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position', 'wheel_moves_intervals',
                 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [Intervals, GoCueTimes, ResponseTimes, Choice, StimOnOffFreezeTimes, ProbaContrasts,
                      FeedbackTimes, FeedbackType, RewardVolume, Wheel]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)

        # Outputs that are not returned separately form the columns of the trials table
        columns = {}
        for key, value in results.items():
            if key not in self.var_names:
                columns[key] = value
        table = AlfBunch(columns)
        assert len(table.keys()) == 12

        df = table.to_df()
        return (df, *(results.pop(name) for name in self.var_names if name != 'table'))

142 

143 

class BiasedTrials(BaseBpodTrialsExtractor):
    """
    Top-level trials extractor; extract_all selects it for IBLRIG versions >= 5.0.0.

    Same as training_trials.TrainingTrials except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    The trials table comes from TrialsTableBiased. The trigger times, the included flags
    and phase/position/quiescence are extracted alongside it.
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None, '_ibl_trials.included.npy',
                  None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        extractors = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableBiased, IncludedTrials,
                      PhasePosQuiescence]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        # Return the outputs in the order declared by var_names (which matches save_names)
        return tuple(map(results.pop, self.var_names))

167 

168 

class EphysTrials(BaseBpodTrialsExtractor):
    """
    Same as BiasedTrials except...
     - Contrast, phase, position, probabilityLeft and quiescence are extracted differently

    The trials table comes from TrialsTableEphys instead of TrialsTableBiased. TrialsTableEphys
    takes contrasts and probabilityLeft from the pre-generated session files.
    """
    save_names = ('_ibl_trials.goCueTrigger_times.npy', '_ibl_trials.stimOnTrigger_times.npy', None, None, None, None,
                  '_ibl_trials.table.pqt', None, None, '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy',
                  '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy', None, None,
                  '_ibl_trials.included.npy', None, None, '_ibl_trials.quiescencePeriod.npy')
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'itiIn_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times',
                 'errorCueTrigger_times', 'table', 'stimOff_times', 'stimFreeze_times', 'wheel_timestamps', 'wheel_position',
                 'wheel_moves_intervals', 'wheel_moves_peak_amplitude', 'peakVelocity_times', 'is_final_movement', 'included',
                 'phase', 'position', 'quiescence')

    def _extract(self, extractor_classes=None, **kwargs):
        # NOTE(review): TrialsTableEphys and PhasePosQuiescence both declare 'phase', 'position'
        # and 'quiescence'. Which values are returned depends on how run_extractor_classes
        # merges duplicate outputs; confirm this is intended.
        extractors = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes,
                      StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials,
                      PhasePosQuiescence]
        results, _ = run_extractor_classes(
            extractors, session_path=self.session_path, bpod_trials=self.bpod_trials,
            settings=self.settings, save=False, task_collection=self.task_collection)
        # Return the outputs in the order declared by var_names (which matches save_names)
        return tuple(map(results.pop, self.var_names))

190 

191 

def extract_all(session_path, save=False, bpod_trials=False, settings=False, extra_classes=None,
                task_collection='raw_behavior_data', save_path=None):
    """
    Same as training_trials.extract_all except...
     - there is no RepNum
     - ContrastLR is extracted differently
     - IncludedTrials is only extracted for 5.0.0 or greater

    :param session_path: session folder. Used to load the raw Bpod data and settings when
        they are not given, and passed on to the extractors.
    :param save: passed to run_extractor_classes; presumably controls writing outputs to disk
    :param bpod_trials: raw Bpod trial data; loaded from `session_path` when falsy
    :param settings: session settings dict; loaded from `session_path` when falsy.
        A missing, None or empty 'IBLRIG_VERSION_TAG' is treated as '100.0.0'.
        The dict passed in is not modified.
    :param extra_classes: additional BaseBpodTrialsExtractor subclasses for custom extractions
    :param task_collection: name of the collection (sub-folder) the raw task data is loaded from
    :param save_path: passed to run_extractor_classes as `path_out`
    :return: the tuple (out, fil) returned by run_extractor_classes: the extracted data and,
        presumably, the saved files
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)
    if settings is None:
        settings = {'IBLRIG_VERSION_TAG': '100.0.0'}

    # Normalise a missing/None/empty version tag. Indexing the key directly raised KeyError when
    # it was absent, and writing into the dict mutated the caller's settings, so build a copy.
    if not settings.get('IBLRIG_VERSION_TAG'):
        settings = {**settings, 'IBLRIG_VERSION_TAG': '100.0.0'}

    # Version check
    if parse_version(settings['IBLRIG_VERSION_TAG']) >= parse_version('5.0.0'):
        # We now extract a single trials table
        base = [BiasedTrials]
    else:
        base = [
            GoCueTriggerTimes, Intervals, Wheel, FeedbackType, ContrastLR, ProbabilityLeft, Choice,
            StimOnTimes_deprecated, RewardVolume, FeedbackTimes, ResponseTimes, GoCueTimes, PhasePosQuiescence
        ]

    if extra_classes:
        base.extend(extra_classes)

    out, fil = run_extractor_classes(base, save=save, session_path=session_path, bpod_trials=bpod_trials, settings=settings,
                                     task_collection=task_collection, path_out=save_path)
    return out, fil