Coverage for ibllib/io/extractors/habituation_trials.py: 92%

53 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1import logging 

2import numpy as np 

3 

4import ibllib.io.raw_data_loaders as raw 

5from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes 

6from ibllib.io.extractors.biased_trials import ContrastLR 

7from ibllib.io.extractors.training_trials import ( 

8 FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes 

9) 

10 

11_logger = logging.getLogger(__name__) 

12 

13 

class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract the trials datasets of a habituationChoiceWorld session from raw Bpod data.

    ``var_names`` lists every variable returned by :meth:`_extract`, in order.  ``save_names``
    (set in ``__init__``) holds one entry per variable: an ``_ibl_trials.<name>.npy`` file name,
    or None for the variables listed in ``exclude`` (presumably meaning they are returned but
    not written to disk - confirm against ``BaseBpodTrialsExtractor``).
    """

    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times')

    def __init__(self, *args, **kwargs):
        """Initialize the base extractor and derive ``save_names`` from ``var_names``.

        All positional and keyword arguments are forwarded unchanged to the parent class.
        """
        super().__init__(*args, **kwargs)
        # Variables that are extracted but get no file name
        exclude = {'itiIn_times', 'stimOffTrigger_times',
                   'stimCenter_times', 'stimCenterTrigger_times'}
        self.save_names = tuple(None if name in exclude else f'_ibl_trials.{name}.npy'
                                for name in self.var_names)

    def _extract(self):
        """Extract all habituation trial variables from ``self.bpod_trials``.

        Three BNC1 (stim sync) TTL pulses are expected per trial: the first marks the stimulus
        offset of the *previous* trial, the second the stimulus onset and the third the stimulus
        centring.  When pulses are missing, a time is only assigned where the correct pulse can
        be identified unambiguously; otherwise the value is NaN.

        :returns: a list of arrays ordered as ``self.var_names``.  Each array is truncated to
         the length of ``stimOff_times`` (number of trials minus one), because the stim off
         event of a trial is only observed during the following trial.
        """
        # Get all stim_sync events detected: one sequence of BNC1 events per trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]

        # Report missing events: count the trials that do not have exactly three pulses
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):
            # No trial at all has the expected three pulses
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:
            # Only some of the trials have an unexpected number of pulses
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

        # Extract datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times
        # Take the first entry of each trial's 'stim_center' state and keep its first timestamp
        stim_center_state = np.array([tr['behavior_data']['States timestamps']
                                      ['stim_center'][0] for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0].T

        # StimCenter times
        # We expect there to be 3 pulses per trial; if this is the case, stim center will be
        # the third pulse. If any pulses are missing, we can only be confident of the correct
        # one if exactly one pulse occurs after the stim center trigger
        stim_center_times = np.full(out['stimCenterTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > trigger for pulse in sync) == 1):
                stim_center_times[i] = sync[-1]
        out['stimCenter_times'] = stim_center_times

        # StimOn times
        # We expect there to be 3 pulses per trial; if this is the case, stim on will be the
        # second pulse. If 1 pulse is missing, we can only be confident of the correct one if
        # both pulses occur before the stim center trigger
        stim_on_times = np.full(out['stimOnTrigger_times'].shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, out['stimCenterTrigger_times'])):
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < trigger for pulse in sync) == 2):
                stim_on_times[i] = sync[1]
        out['stimOn_times'] = stim_on_times

        # RewardVolume
        trial_volume = [tr['reward_amount'] for tr in self.bpod_trials]
        out['rewardVolume'] = np.array(trial_volume).astype(np.float64)

        # StimOffTrigger times
        # StimOff occurs at trial start (ignore the first trial's state update), hence this
        # array has one element fewer than the number of trials
        out['stimOffTrigger_times'] = np.array(
            [tr['behavior_data']['States timestamps']
             ['trial_start'][0][0] for tr in self.bpod_trials[1:]]
        )

        # StimOff times
        # There should be exactly three TTLs per trial. stimOff_times should be the first TTL
        # pulse of the following trial. If 1 or more pulses are missing, we can not be
        # confident of assigning the correct one.
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan
                                         for sync in ttls[1:]])

        # FeedbackType is always positive
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # ItiIn times: first timestamp of the first 'iti' state entry of each trial
        out['itiIn_times'] = np.array(
            [tr['behavior_data']['States timestamps']
             ['iti'][0][0] for tr in self.bpod_trials]
        )

        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        return [out[k][:n_trials] for k in self.var_names]

107 

108 

def extract_all(session_path, save=False, bpod_trials=False, settings=False,
                task_collection='raw_behavior_data', save_path=None):
    """Extract all datasets from habituationChoiceWorld

    Note: only the datasets from the HabituationTrials extractor will be saved to disc.

    :param session_path: The session path where the raw data are saved
    :param save: If True, the datasets that are considered standard are saved to the session path
    :param bpod_trials: The raw Bpod trial data; loaded from the session path when falsy
    :param settings: The raw Bpod settings; loaded from the session path when falsy
    :param task_collection: The collection passed to the raw data loaders and to the extractor
    :param save_path: Forwarded to the extractor as its ``path_out`` argument
    :returns: a dict of datasets and a corresponding list of file names
    """
    # Fall back to loading the raw data from disk for anything the caller did not provide
    if not bpod_trials:
        bpod_trials = raw.load_data(session_path, task_collection=task_collection)
    if not settings:
        settings = raw.load_settings(session_path, task_collection=task_collection)

    # Standard datasets that may be saved as ALFs
    datasets, files = run_extractor_classes(
        HabituationTrials,
        save=save,
        session_path=session_path,
        bpod_trials=bpod_trials,
        settings=settings,
        task_collection=task_collection,
        path_out=save_path,
    )
    return datasets, files