Coverage for ibllib/io/extractors/bpod_trials.py: 85%

61 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1"""Trials data extraction from raw Bpod output 

2This module will extract the Bpod trials and wheel data based on the task protocol, 

3i.e. habituation, training or biased. 

4""" 

5import logging 

6import importlib 

7from collections import OrderedDict 

8import warnings 

9 

10from pkg_resources import parse_version 

11from ibllib.io.extractors import habituation_trials, training_trials, biased_trials, opto_trials 

12from ibllib.io.extractors.base import get_bpod_extractor_class, protocol2extractor 

13from ibllib.io.extractors.habituation_trials import HabituationTrials 

14from ibllib.io.extractors.training_trials import TrainingTrials 

15from ibllib.io.extractors.biased_trials import BiasedTrials, EphysTrials 

16from ibllib.io.extractors.base import get_session_extractor_type, BaseBpodTrialsExtractor 

17import ibllib.io.raw_data_loaders as rawio 

18 

19_logger = logging.getLogger(__name__) 

20 

21 

def _split_wheel(trials):
    """
    Pop all wheel-related entries out of an extracted trials dict.

    The training and biased trials extractors return the wheel data alongside the trials data;
    splitting it off here avoids extracting the wheel twice.

    Parameters
    ----------
    trials : dict
        The extracted trials data. Modified in place: every key containing 'wheel' is removed.

    Returns
    -------
    collections.OrderedDict
        The removed wheel entries, in their original key order.
    """
    # Snapshot the matching keys first because entries are popped from the dict afterwards.
    wheel_keys = [k for k in tuple(trials.keys()) if 'wheel' in k]
    return OrderedDict((k, trials.pop(k)) for k in wheel_keys)


def extract_all(session_path, save=True, bpod_trials=None, settings=None,
                task_collection='raw_behavior_data', extractor_type=None, save_path=None):
    """
    Extracts a training session from its path. NB: Wheel must be extracted first in order to
    extract trials.firstMovement_times.

    Note: calling this function emits a FutureWarning; prefer `bpod_trials.get_bpod_extractor`.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    save : bool
        If true, save the output files to save_path.
    bpod_trials : list of dict
        The loaded Bpod trial data. If None, it is loaded from task_collection with
        `rawio.load_data`.
    settings : dict
        The loaded Bpod settings. If None, they are loaded from task_collection with
        `rawio.load_settings`.
    task_collection : str
        The subfolder containing the raw Bpod data files.
    extractor_type : str
        The type of extraction. Supported types are {'ephys', 'biased', 'biased_opto',
        'ephys_biased_opto', 'training', 'ephys_training', 'habituation'}. Apart from
        'training'/'ephys_training', any type containing 'biased' or 'ephys' is extracted with
        the biased trials extractor. If None, the type is determined with
        `get_session_extractor_type`.
    save_path : str, pathlib.Path
        The location of the output files if save is true. Defaults to <session_path>/alf.

    Returns
    -------
    dict
        The extracted trials data (wheel entries removed for training/biased sessions).
    collections.OrderedDict or None
        The extracted wheel data; None for habituation sessions.
    list of pathlib.Path or None
        The output files if save is true, otherwise None.

    Legacy habituation sessions (IBLRIG_VERSION_TAG <= 5.0.0) are not extracted and
    (None, None, None) is returned.

    Raises
    ------
    ValueError
        If there is no extractor for the given extractor_type.
    """
    warnings.warn('`extract_all` functions soon to be deprecated, use `bpod_trials.get_bpod_extractor` instead', FutureWarning)
    if not extractor_type:
        extractor_type = get_session_extractor_type(session_path, task_collection=task_collection)
    _logger.info(f'Extracting {session_path} as {extractor_type}')
    # Only load from disk when the caller supplied nothing; an explicitly passed empty
    # container is respected instead of being silently replaced.
    if bpod_trials is None:
        bpod_trials = rawio.load_data(session_path, task_collection=task_collection)
    if settings is None:
        settings = rawio.load_settings(session_path, task_collection=task_collection)
    _logger.info(f'{extractor_type} session on {settings.get("PYBPOD_BOARD")}')

    # Determine which additional extractors are required
    extra = []
    if extractor_type == 'ephys':  # Should exclude 'ephys_biased'
        _logger.debug('Engaging biased TrialsTableEphys')
        extra.append(biased_trials.TrialsTableEphys)
    if extractor_type in ['biased_opto', 'ephys_biased_opto']:
        _logger.debug('Engaging opto_trials LaserBool')
        extra.append(opto_trials.LaserBool)

    # Determine base extraction
    if extractor_type in ['training', 'ephys_training']:
        trials, files_trials = training_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        wheel = _split_wheel(trials)
    elif 'biased' in extractor_type or 'ephys' in extractor_type:
        trials, files_trials = biased_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save, extra_classes=extra,
            task_collection=task_collection, save_path=save_path)
        wheel = _split_wheel(trials)
    elif extractor_type == 'habituation':
        # .get: settings without a version tag are treated like the falsy-tag case.
        version_tag = settings.get('IBLRIG_VERSION_TAG')
        if version_tag and parse_version(version_tag) <= parse_version('5.0.0'):
            _logger.warning('No extraction of legacy habituation sessions')
            return None, None, None
        trials, files_trials = habituation_trials.extract_all(
            session_path, bpod_trials=bpod_trials, settings=settings, save=save,
            task_collection=task_collection, save_path=save_path)
        wheel = None  # no wheel data is extracted for habituation sessions
    else:
        raise ValueError(f'No extractor for task {extractor_type}')
    _logger.info('session extracted \n')  # timing info in log
    # When saving, files_trials should already contain the wheel files at the end.
    return trials, wheel, (list(files_trials) if save else None)

103 

104 

def get_bpod_extractor(session_path, protocol=None, task_collection='raw_behavior_data') -> BaseBpodTrialsExtractor:
    """
    Returns an extractor for a given session.

    Built-in extractor classes are looked up by name. Any other name is treated as a custom
    extractor in the personal projects repo and imported from the 'projects' namespace (the
    'projects.' prefix is added when missing).

    Parameters
    ----------
    session_path : str, pathlib.Path
        The path to the session to be extracted.
    protocol : str, optional
        The protocol name, otherwise uses the PYBPOD_PROTOCOL key in iblrig task settings files.
    task_collection : str
        The folder within the session that contains the raw task data.

    Returns
    -------
    BaseBpodTrialsExtractor
        An instance of the task extractor class, instantiated with the session path.

    Raises
    ------
    ValueError
        If no extractor class name could be determined, or the named class is not found in its
        module.
    ModuleNotFoundError
        If the module of a custom extractor cannot be imported.
    """
    builtin_extractors = {
        'HabituationTrials': HabituationTrials,
        'TrainingTrials': TrainingTrials,
        'BiasedTrials': BiasedTrials,
        'EphysTrials': EphysTrials
    }
    if protocol:
        class_name = protocol2extractor(protocol)
    else:
        class_name = get_bpod_extractor_class(session_path, task_collection=task_collection)
    # NOTE(review): guards against the lookup helpers returning nothing (e.g. None), which
    # previously surfaced as an AttributeError on `.startswith` below.
    if not class_name:
        raise ValueError(f'No extractor class found for session {session_path} (protocol: {protocol})')
    if class_name in builtin_extractors:
        return builtin_extractors[class_name](session_path)

    # look if there are custom extractor types in the personal projects repo
    qualified_name = class_name if class_name.startswith('projects.') else 'projects.' + class_name
    module_name, _, extractor_name = qualified_name.rpartition('.')
    mdl = importlib.import_module(module_name)
    extractor_class = getattr(mdl, extractor_name, None)
    if not extractor_class:
        # Report the full dotted path so the missing module/class pair is identifiable.
        raise ValueError(f'extractor {qualified_name} not found')
    return extractor_class(session_path)