Coverage for ibllib/io/extractors/habituation_trials.py: 98%

64 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 14:26 +0100

1"""Habituation ChoiceWorld Bpod trials extraction.""" 

2import logging 

3import numpy as np 

4 

5from packaging import version 

6 

7import ibllib.io.raw_data_loaders as raw 

8from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

9from ibllib.io.extractors.biased_trials import ContrastLR 

10from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes 

11 

12_logger = logging.getLogger(__name__) 

13 

14 

class HabituationTrials(BaseBpodTrialsExtractor):
    """
    Bpod trials extractor for the habituation ChoiceWorld task.

    The extracted variables are listed in `var_names`. A subset of them ('itiIn_times',
    'stimCenter_times', 'stimCenterTrigger_times', 'position' and 'phase') is given a save name
    of None in `save_names`; all others map to an '_ibl_trials.<name>.npy' file name.
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times', 'position', 'phase')

    def __init__(self, *args, **kwargs):
        """Initialize the extractor and build `save_names` in the same order as `var_names`."""
        super().__init__(*args, **kwargs)
        # Variables that get no file name (None) in `save_names`
        exclude = ['itiIn_times', 'stimCenter_times', 'stimCenterTrigger_times', 'position', 'phase']
        self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names)

    def _extract(self) -> dict:
        """
        Extract the Bpod trial events.

        For iblrig versions < 8.13 the Bpod state machine for this task had extremely misleading names!
        The 'iti' state was actually the delay between valve close and trial end (the stimulus is
        still present during this period), and the 'trial_start' state is actually the ITI during
        which there is a 1s Bpod TTL and gray screen period.

        In version 8.13 and later, the 'iti' state was renamed to 'post_reward' and 'trial_start'
        was renamed to 'iti'.

        Returns
        -------
        dict
            A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute.

        Raises
        ------
        AssertionError
            If any detected stim on time occurs before the start of its trial interval.
        """
        # Extract all trials...

        # Get all detected TTLs. These are stored for QC purposes
        self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials)
        # These are the frame2TTL pulses as a list of lists, one per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report missing events (exactly three pulses are expected per trial)
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        # Check if all stim syncs have failed to be detected
        if n_missing == len(ttls):
            _logger.error('%s: Missing ALL BNC1 TTLs (%s trials)', self.session_path, n_missing)
        elif n_missing > 0:  # Check if any stim sync has failed to be detected on some trials
            _logger.warning('%s: Missing BNC1 TTLs on %s trial(s)', self.session_path, n_missing)

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # The 'trial_start'/'iti' state is in fact the 1s grey screen period, therefore the first
        # timestamp is really the end of the previous trial and also the stimOff trigger time. The
        # second timestamp is the true trial start time. This state was renamed in version 8.13.
        state_names = self.bpod_trials[0]['behavior_data']['States timestamps'].keys()
        rig_version = version.parse(self.settings['IBLRIG_VERSION'])
        # Fall back on the old state name if the state names look like the pre-8.13 ones,
        # regardless of the version recorded in the settings
        legacy_state_machine = 'post_reward' not in state_names and 'trial_start' in state_names

        key = 'iti' if (rig_version >= version.parse('8.13') and not legacy_state_machine) else 'trial_start'
        # Take the last (first timestamp, second timestamp) pair of the state for each trial.
        # The first trial's first timestamp is discarded as it doesn't end any extracted trial.
        (_, *ends), starts = zip(*[
            t['behavior_data']['States timestamps'][key][-1] for t in self.bpod_trials]
        )

        # StimOffTrigger times
        out['stimOffTrigger_times'] = np.array(ends)

        # StimOff times
        # There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse
        # of the following trial. If 1 or more pulses are missing, we can not be confident of
        # assigning the correct one.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]])

        # Trial intervals
        # In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim
        # off time often happens after the trial end TTL front, i.e. after the 'trial_start' state
        # begins. For these trials, we set the trial end time as the stim off time.
        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :]

        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
        if np.any(to_correct):
            _logger.debug(
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
                sum(to_correct), len(to_correct))
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]

        # itiIn times
        out['itiIn_times'] = np.r_[ends, np.nan]

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times
        # Get the first timestamp of the first 'stim_center' state entry for each trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']
                                      ['stim_center'][0] for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0].T

        # StimCenter times
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim center will
            # be the third pulse. If any pulses are missing, we can only be confident of the correct
            # one if exactly one pulse occurs after the stim center trigger
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > last for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        stimOn_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, last) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            # We expect there to be 3 pulses per trial; if this is the case, stim on will be the
            # second pulse. If 1 pulse is missing, we can only be confident of the correct one if
            # both pulses occur before the stim center trigger
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < last for pulse in sync) == 2):
                stimOn_times[i] = sync[1]
        out['stimOn_times'] = stimOn_times

        # RewardVolume
        trial_volume = [x['reward_amount'] for x in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # Phase and position
        out['position'] = np.array([t['position'] for t in self.bpod_trials])
        out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials])

        # Double-check that the detected stim on events do not precede their trial's start.
        # An explicit raise is used instead of an `assert` statement so that the check is not
        # stripped when Python runs with optimizations enabled (-O).
        idx = ~np.isnan(out['stimOn_times'][:n_trials])
        if np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]):
            raise AssertionError('Stim on events occurring outside trial intervals')

        # Truncate arrays and return in correct order
        return {k: out[k][:n_trials] for k in self.var_names}