Coverage for ibllib/io/extractors/habituation_trials.py: 98%

64 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1"""Habituation ChoiceWorld Bpod trials extraction.""" 

2import logging 

3import numpy as np 

4 

5from packaging import version 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.biased_trials import ContrastLR 

10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes 

11 

12_logger = logging.getLogger(__name__) 

13 

14 

class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract Bpod trial events for the habituationChoiceWorld task.

    Datasets listed in `exclude` in ``__init__`` (ITI entry, stim centre times, stimulus position
    and phase) are still extracted but are given no save name (``None`` in `save_names`).
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # These variables are extracted but get no '_ibl_trials.*.npy' file name
        exclude = ['itiIn_times', 'stimCenter_times', 'stimCenterTrigger_times', 'position', 'phase']
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading names!
        The 'iti' state was actually the delay between valve close and trial end (the stimulus is
        still present during this period), and the 'trial_start' state is actually the ITI during
        which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.
            Every array is truncated to the number of Bpod trials minus one, because the stim off
            event of a trial is only detected during the following trial.

        Raises
        ------
        AssertionError
            If any detected stim on time precedes the start of its trial interval.
        """
        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        # These are the frame2TTL (BNC1) pulses as a list of lists, one per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report missing events: a complete trial has exactly three BNC1 pulses
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):
            # All stim syncs have failed to be detected
            _logger.error('%s: Missing ALL BNC1 TTLs (%i trials)', self.session_path, n_missing)
        elif n_missing > 0:
            # Some, but not all, trials have failed to have their stim syncs detected
            _logger.warning('%s: Missing BNC1 TTLs on %i trial(s)', self.session_path, n_missing)

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(
            training, session_path=self.session_path, save=False, bpod_trials=self.bpod_trials,
            settings=self.settings, task_collection=self.task_collection)

        # The 'trial_start' (or, from version 8.13, 'iti') state is in fact the 1s grey screen
        # period, therefore its first timestamp is really the end of the previous trial and also
        # the stimOff trigger time. Its second timestamp is the true trial start time.
        state_names = self.bpod_trials[0]['behavior_data']['States timestamps'].keys()
        rig_version = version.parse(self.settings['IBLRIG_VERSION'])
        # Use the legacy state name whenever the state machine itself still has the old layout
        # (no 'post_reward' state, but a 'trial_start' state), regardless of the rig version
        legacy_state_machine = 'post_reward' not in state_names and 'trial_start' in state_names

        key = 'iti' if (rig_version >= version.parse('8.13') and not legacy_state_machine) else 'trial_start'
        # Last occurrence of that state in each trial, split into its start and end timestamps
        state_starts, state_ends = zip(*[
            t['behavior_data']['States timestamps'][key][-1] for t in self.bpod_trials])
        ends = list(state_starts[1:])  # end of trials 0 .. n-2 (state start of the next trial)
        starts = state_ends  # start of trials 0 .. n-1

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times: there should be exactly three TTLs per trial and the stim off of a trial
        # is the first TTL of the following trial. If 1 or more pulses are missing, we cannot be
        # confident of assigning the correct one.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

        # Trial intervals: in terms of TTLs, the intervals are defined by the 'trial_start' state,
        # however the stim off time often happens after the trial end TTL front, i.e. after the
        # 'trial_start' state begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
        if np.any(to_correct):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                sum(to_correct), len(to_correct))
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

        # itiIn times (NaN-padded for the last trial, which is truncated on return)
        out['itiIn_times'] = np.r_[ends, np.nan]

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: start timestamp of the first 'stim_center' state of each trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']['stim_center'][0]
                                      for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0]

        # StimCenter times: we expect there to be 3 pulses per trial; if this is the case, stim
        # center will be the third (last) pulse. If any pulses are missing, we can only be
        # confident of the correct one if exactly one pulse occurs after the stim center trigger
        # (taking the last pulse assumes the pulses are in chronological order).
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > trigger for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times: with 3 pulses per trial, stim on will be the second pulse. If 1 pulse is
        # missing, we can only be confident of the correct one if both remaining pulses occur
        # before the stim center trigger.
        stimOn_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) == 2 and all(pulse < trigger for pulse in sync)):
                stimOn_times[i] = sync[1]
        out['stimOn_times'] = stimOn_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([t['position'] for t in self.bpod_trials])
        out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

        # Sanity check: detected stim on events must not precede the start of their trial interval.
        # Raised explicitly (rather than via `assert`) so the check still runs under `python -O`.
        stim_on = out['stimOn_times'][:n_trials]
        idx = ~np.isnan(stim_on)
        if np.any(stim_on[idx] < out['intervals'][idx, 0]):
            raise AssertionError('Stim on events occurring outside trial intervals')

        # Truncate arrays and return in correct order
        return {k: out[k][:n_trials] for k in self.var_names}