Coverage for ibllib/io/extractors/habituation_trials.py: 89%

72 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""Habituation ChoiceWorld Bpod trials extraction.""" 

2import logging 

3import numpy as np 

4 

5from packaging import version 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.biased_trials import ContrastLR 

10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes 

11 

12_logger = logging.getLogger(__name__) 

13 

14 

class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract trial events from habituationChoiceWorld Bpod data.

    The extracted datasets are listed in `var_names`. Datasets in the `exclude` list of
    ``__init__`` are given a save name of None (presumably skipped when saving by the base
    class); all others are mapped to an ``_ibl_trials.<name>.npy`` file name.
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # These datasets are extracted but receive no ALF file name
        exclude = ['itiIn_times', 'stimOffTrigger_times', 'stimCenter_times',
                   'stimCenterTrigger_times', 'position', 'phase']
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading names!
        The 'iti' state was actually the delay between valve close and trial end (the stimulus is
        still present during this period), and the 'trial_start' state is actually the ITI during
        which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.
            Every array is truncated to (number of Bpod trials - 1) because the stim off event of
            a trial is only recorded in the following trial.

        Raises
        ------
        AssertionError
            If a detected stim on time precedes the start of its trial interval.
        """
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        # These are the frame2TTL pulses as a list of lists, one per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report missing events: exactly three BNC1 pulses are expected per trial
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):  # the stim sync failed to be detected on every trial
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:  # the stim sync failed to be detected on some trials
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # The 'trial_start'/'iti' state is in fact the 1s grey screen period, therefore its first
        # timestamp is really the end of the previous trial and also the stimOff trigger time. Its
        # second timestamp is the true trial start time. This state was renamed in version 8.13.
        state_names = self.bpod_trials[0]['behavior_data']['States timestamps'].keys()
        rig_version = version.parse(self.settings['IBLRIG_VERSION'])
        # If the data contain the old state names, use 'trial_start' even when the rig version is
        # >= 8.13; the state names present in the data take precedence over the version number
        legacy_state_machine = 'post_reward' not in state_names and 'trial_start' in state_names

        key = 'iti' if (rig_version >= version.parse('8.13') and not legacy_state_machine) else 'trial_start'
        # Take the last occurrence of the state in each trial and transpose the [start, end] pairs
        # into (all state starts, all state ends). The state start of trial i+1 is the end of
        # trial i (so the first one is discarded); the state end is the start of the trial.
        (_, *ends), starts = zip(*[
            t['behavior_data']['States timestamps'][key][-1] for t in self.bpod_trials]
        )

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times
        # There should be exactly three TTLs per trial; the stim off of a trial is the first TTL
        # pulse of the following trial. If 1 or more pulses are missing, we can not be confident
        # of assigning the correct one.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

        # Trial intervals
        # In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim
        # off time often happens after the trial end TTL front, i.e. after the 'trial_start' state
        # begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
        if np.any(to_correct):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                sum(to_correct), len(to_correct))
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

        # itiIn times
        out['itiIn_times'] = np.r_[ends, np.nan]

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: the start of the first 'stim_center' state of each trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']
                                      ['stim_center'][0] for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0]

        # StimCenter times
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim center will be
            # the third pulse. If any pulses are missing, we can only be confident of the correct
            # one if exactly one pulse occurs after the stim center trigger.
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > last for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        stimOn_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim on will be the
            # second pulse. If 1 pulse is missing, we can only be confident of the correct one if
            # both pulses occur before the stim center trigger.
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < last for pulse in sync) == 2):
                stimOn_times[i] = sync[1]
        out['stimOn_times'] = stimOn_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([t['position'] for t in self.bpod_trials])
        out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

        # Sanity check that no detected stim on event precedes the start of its trial interval.
        # An explicit raise is used instead of `assert` so the check survives `python -O`.
        idx = ~np.isnan(out['stimOn_times'][:n_trials])
        if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]):
            raise AssertionError('Stim on events occurring outside trial intervals')

        # Truncate arrays and return in correct order
        return {k: out[k][:n_trials] for k in self.var_names}

155 

156 

def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
    """Extract all datasets from habituationChoiceWorld.

    Note: only the datasets from the HabituationTrials extractor will be saved to disc.

    Parameters
    ----------
    session_path
        The session path where the raw data are saved.
    save : bool
        If True, the datasets that are considered standard are saved.
    bpod_trials : list of dict, optional
        The raw Bpod trial data. If not provided (or falsy), they are loaded from `session_path`.
    settings : dict, optional
        The raw Bpod session settings. If not provided (or falsy), they are loaded from
        `session_path`.
    task_collection : str
        The collection containing the raw task data; passed to the raw data loaders and to the
        extractor.
    save_path : optional
        Output location for saved datasets; passed to the extractor as `path_out`.

    Returns
    -------
    dict
        The extracted datasets.
    list
        The corresponding list of file names.
    """
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)

    # Standard datasets that may be saved as ALFs
    params = dict(session_path=session_path, bpod_trials=bpod_trials, settings=settings, task_collection=task_collection,
                  path_out=save_path)
    out, fil = run_extractor_classes(HabituationTrials, save=save, **params)
    return out, fil