Coverage for ibllib/io/extractors/bpod_trials.py: 80%

61 statements  

Generated by coverage.py v7.5.4 on 2024-07-08 17:16 +0100.

1"""Trials data extraction from raw Bpod output. 

2 

3This module will extract the Bpod trials and wheel data based on the task protocol, 

4i.e. habituation, training or biased. 

5""" 

6import logging 

7import importlib 

8from collections import OrderedDict 

9import warnings 

10 

11from packaging import version 

12from ibllib.io.extractors import habituation_trials, training_trials, biased_trials, opto_trials 

13from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor 

14from ibllib.io.extractors.habituation_trials import HabituationTrials 

15from ibllib.io.extractors.training_trials import TrainingTrials 

16from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials 

17from ibllib.io.extractors.base import get_session_extractor_type, BaseBpodTrialsExtractor 

18import ibllib.io.raw_data_loaders as rawio 

19 

# Logger for this module; named after the module so log records can be filtered per module.
_logger = logging.getLogger(name=__name__)

21 

22 

def _split_wheel(trials):
    """
    Move every wheel-related entry out of an extracted trials dict.

    The training and biased extractors return wheel data within the trials dict; popping it
    here is a (somewhat hacky) way of avoiding extracting the wheel twice.

    Parameters
    ----------
    trials : dict
        The extracted trials data. Modified in place: every key containing 'wheel' is removed.

    Returns
    -------
    collections.OrderedDict
        The removed wheel entries, in their original key order.
    """
    # Snapshot the matching keys first because the dict is mutated while popping.
    wheel_keys = [key for key in trials if 'wheel' in key]
    return OrderedDict((key, trials.pop(key)) for key in wheel_keys)


def extract_all(session_path, save=True, bpod_trials=None, settings=None,
                task_collection='raw_behavior_data', extractor_type=None, save_path=None):
    """
    Extract the trials (and, where available, wheel) data of a Bpod session from its path.

    NB: Wheel must be extracted first in order to extract trials.firstMovement_times.

    Deprecated: emits a FutureWarning; use `bpod_trials.get_bpod_extractor` instead.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    save : bool
        If true, save the output files to save_path.
    bpod_trials : list of dict
        The loaded Bpod trial data. If None (or otherwise empty), attempts to load
        _iblrig_taskData.raw from task_collection.
    settings : dict
        The loaded Bpod settings. If None (or otherwise empty), attempts to load
        _iblrig_taskSettings.raw from task_collection.
    task_collection : str
        The subfolder containing the raw Bpod data files.
    extractor_type : str
        The type of extraction. Supported types are {'ephys', 'biased', 'biased_opto',
        'ephys_biased_opto', 'training', 'ephys_training', 'habituation'}. If None, extractor type
        determined from settings.
    save_path : str, pathlib.Path
        The location of the output files if save is true. Defaults to <session_path>/alf.

    Returns
    -------
    dict
        The extracted trials data.
    collections.OrderedDict, None
        The extracted wheel data (None for habituation sessions).
    list of pathlib.Path, None
        The output files if save is true, otherwise None.

    For legacy habituation sessions (IBLRIG_VERSION <= 5.0.0) nothing is extracted and
    (None, None, None) is returned.

    Raises
    ------
    ValueError
        There is no extractor for the given extractor type.
    """
    warnings.warn('`extract_all` functions soon to be removed, use `bpod_trials.get_bpod_extractor` instead', FutureWarning)
    if not extractor_type:
        extractor_type = get_session_extractor_type(session_path, task_collection=task_collection)
    _logger.info(f'Extracting {session_path} as {extractor_type}')
    bpod_trials = bpod_trials or rawio.load_data(session_path, task_collection=task_collection)
    settings = settings or rawio.load_settings(session_path, task_collection=task_collection)
    # The board name is only logged, so a missing key must not abort the extraction.
    _logger.info(f'{extractor_type} session on {settings.get("PYBPOD_BOARD")}')

    # Determine which additional extractors are required
    extra = []
    if extractor_type == 'ephys':  # Exact match, so 'ephys_biased' is deliberately excluded
        _logger.debug('Engaging biased TrialsTableEphys')
        extra.append(biased_trials.TrialsTableEphys)
    if extractor_type in ('biased_opto', 'ephys_biased_opto'):
        _logger.debug('Engaging opto_trials LaserBool')
        extra.append(opto_trials.LaserBool)

    # Determine base extraction
    if extractor_type in ('training', 'ephys_training'):
        trials, files_trials = training_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        wheel = _split_wheel(trials)
    elif 'biased' in extractor_type or 'ephys' in extractor_type:
        trials, files_trials = biased_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save, extra_classes=extra,
            task_collection=task_collection, save_path=save_path)
        wheel = _split_wheel(trials)
    elif extractor_type == 'habituation':
        # A missing or empty version is treated as "unknown" and extraction proceeds.
        iblrig_version = settings.get('IBLRIG_VERSION')
        if iblrig_version and version.parse(iblrig_version) <= version.parse('5.0.0'):
            _logger.warning('No extraction of legacy habituation sessions')
            return None, None, None
        trials, files_trials = habituation_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        wheel = None
    else:
        raise ValueError(f'No extractor for task {extractor_type}')
    _logger.info('session extracted \n')  # timing info in log
    # files_trials is expected to already include any wheel files, as the wheel is not
    # extracted separately; return a fresh list so callers cannot mutate the extractor's list.
    return trials, wheel, list(files_trials) if save else None

104 

105 

def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
    """
    Build the trials extractor matching a Bpod session.

    The extractor class name is derived from `protocol` when one is given, otherwise from the
    iblrig task settings found in `task_collection`. Extractor classes bundled with ibllib are
    used directly; any other name is looked up in the personal ``projects`` package.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session to extract; passed to the extractor's constructor.
    protocol : str, optional
        The protocol name. If empty or None, the PYBPOD_PROTOCOL key in the iblrig task
        settings files is used instead.
    task_collection : str
        The session sub-folder that contains the raw task data.

    Returns
    -------
    BaseBpodTrialsExtractor
        An instance of the task extractor class, constructed with `session_path`.

    Raises
    ------
    ImportError
        The ``projects`` module holding a custom extractor cannot be imported
        (raised by `importlib.import_module`).
    ValueError
        The custom module was imported but does not define the extractor class.
    """
    # Extractor classes shipped with ibllib, keyed by the class names the resolvers return.
    builtin_extractors = dict(
        HabituationTrials=HabituationTrials,
        TrainingTrials=TrainingTrials,
        BiasedTrials=BiasedTrials,
        EphysTrials=EphysTrials,
    )
    extractor_name = (protocol2extractor(protocol) if protocol
                      else get_bpod_extractor_class(session_path, task_collection=task_collection))

    builtin_cls = builtin_extractors.get(extractor_name)
    if builtin_cls is not None:
        return builtin_cls(session_path)

    # Not a builtin: look for a custom extractor in the personal projects repo.
    if extractor_name.startswith('projects.'):
        qualified_name = extractor_name
    else:
        qualified_name = 'projects.' + extractor_name
    # qualified_name always contains a dot here, so rpartition splits on the last one.
    module_name, _, attr_name = qualified_name.rpartition('.')
    extractor_cls = getattr(importlib.import_module(module_name), attr_name, None)
    if not extractor_cls:
        raise ValueError(f'extractor {attr_name} not found')
    return extractor_cls(session_path)