Coverage for ibllib/ephys/spikes.py: 64%

120 statements  

coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

from pathlib import Path
import logging
import json
import shutil
import tarfile

import numpy as np
from one.alf.files import get_session_path
import spikeglx

from iblutil.util import Bunch
import phylib.io.alf
from ibllib.ephys.sync_probes import apply_sync
import ibllib.ephys.ephysqc as ephysqc
from ibllib.ephys import sync_probes

_logger = logging.getLogger(__name__)


def probes_description(ses_path, one):
    """
    Aggregate probes information into ALF files.
    Register Alyx probe insertions and micro-manipulator trajectories.
    Input:
        raw_ephys_data/probeXX/
    Output:
        alf/probes.description.json
    """

    eid = one.path2eid(ses_path, query_type='remote')
    ses_path = Path(ses_path)
    meta_files = spikeglx.glob_ephys_files(ses_path, ext='meta')
    ap_meta_files = [(ep.ap.parent, ep.label, ep) for ep in meta_files if ep.get('ap')]
    # if we don't detect any meta files, exit the function
    if len(ap_meta_files) == 0:
        return

    subdirs, labels, efiles_sorted = zip(*sorted(ap_meta_files))

    def _create_insertion(md, label, eid):

        # create json description
        description = {'label': label, 'model': md.neuropixelVersion, 'serial': int(md.serial),
                       'raw_file_name': md.fileName}

        # create or update probe insertion on alyx
        alyx_insertion = {'session': eid, 'model': md.neuropixelVersion, 'serial': md.serial, 'name': label}
        pi = one.alyx.rest('insertions', 'list', session=eid, name=label)
        if len(pi) == 0:
            qc_dict = {'qc': 'NOT_SET', 'extended_qc': {}}
            alyx_insertion.update({'json': qc_dict})
            insertion = one.alyx.rest('insertions', 'create', data=alyx_insertion)
        else:
            insertion = one.alyx.rest('insertions', 'partial_update', data=alyx_insertion, id=pi[0]['id'])

        return description, insertion

    # Outputs the probes description file
    probe_description = []
    alyx_insertions = []
    for label, ef in zip(labels, efiles_sorted):
        md = spikeglx.read_meta_data(ef.ap.with_suffix('.meta'))
        if md.neuropixelVersion == 'NP2.4':
            # NP2.4 meta that hasn't been split
            if md.get('NP2.4_shank', None) is None:
                geometry = spikeglx.read_geometry(ef.ap.with_suffix('.meta'))
                nshanks = np.unique(geometry['shank'])
                for shank in nshanks:
                    label_ext = f'{label}{chr(97 + int(shank))}'
                    description, insertion = _create_insertion(md, label_ext, eid)
                    probe_description.append(description)
                    alyx_insertions.append(insertion)
            # NP2.4 meta that has already been split
            else:
                description, insertion = _create_insertion(md, label, eid)
                probe_description.append(description)
                alyx_insertions.append(insertion)
        else:
            description, insertion = _create_insertion(md, label, eid)
            probe_description.append(description)
            alyx_insertions.append(insertion)

    alf_path = ses_path.joinpath('alf')
    alf_path.mkdir(exist_ok=True, parents=True)
    probe_description_file = alf_path.joinpath('probes.description.json')
    with open(probe_description_file, 'w+') as fid:
        fid.write(json.dumps(probe_description))

    return [probe_description_file]

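
def _example_probes_description():
    # Usage sketch (illustrative only, not part of the original module): register the probe
    # insertions of a local session. The session path below is hypothetical and a configured
    # ONE/Alyx connection is assumed.
    from one.api import ONE
    one = ONE()
    ses_path = '/data/Subjects/ZM_1150/2020-01-01/001'  # hypothetical local session path
    return probes_description(ses_path, one)
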

def sync_spike_sorting(ap_file, out_path):
    """
    Synchronizes the spikes.times using the previously computed sync files
    :param ap_file: raw binary data file for the probe insertion
    :param out_path: probe output path (usually {session_path}/alf/{probe_label})
    """

    def _sr(ap_file):
        # gets the sampling rate from the meta-data
        md = spikeglx.read_meta_data(ap_file.with_suffix('.meta'))
        return spikeglx._get_fs_from_meta(md)

    out_files = []
    label = ap_file.parts[-1]  # now the bin file is always in a folder bearing the name of the probe
    sync_file = ap_file.parent.joinpath(
        ap_file.name.replace('.ap.', '.sync.')).with_suffix('.npy')
    # try to compute the probe sync if the sync file doesn't exist
    if not sync_file.exists():
        _, sync_files = sync_probes.sync(get_session_path(ap_file))
        out_files.extend(sync_files)
    # if it is still not there, this is a full-blown error
    if not sync_file.exists():
        # if there is no sync file it means something went wrong. Outputs the spike sorting
        # in time according to the probe by following the ALF convention on the times objects
        error_msg = f'No synchronisation file for {label}: {sync_file}. The spike-' \
                    f'sorting is not synchronized and data not uploaded on Flat-Iron'
        _logger.error(error_msg)
        # remove the alf folder if the sync failed
        shutil.rmtree(out_path)
        return None, 1
    # patch the spikes.times file manually
    st_file = out_path.joinpath('spikes.times.npy')
    spike_samples = np.load(out_path.joinpath('spikes.samples.npy'))
    interp_times = apply_sync(sync_file, spike_samples / _sr(ap_file), forward=True)
    np.save(st_file, interp_times)
    # get the list of output files
    out_files.extend([f for f in out_path.glob("*.*") if
                      f.name.startswith(('channels.', 'drift', 'clusters.', 'spikes.', 'templates.',
                                         '_kilosort_', '_phy_spikes_subset', '_ibl_log.info'))])
    # the QC files computed during spike sorting stay within the raw ephys data folder
    out_files.extend(list(ap_file.parent.glob('_iblqc_*AP.*.npy')))
    return out_files, 0

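
def _example_sync_spike_sorting(session_path):
    # Usage sketch (illustrative only, not part of the original module): synchronise the
    # spike times of one probe. The folder layout follows the IBL convention but the file
    # names below are hypothetical.
    session_path = Path(session_path)
    ap_file = session_path.joinpath('raw_ephys_data', 'probe00',
                                    '_spikeglx_ephysData_g0_t0.imec0.ap.cbin')
    out_path = session_path.joinpath('alf', 'probe00')
    out_files, error_code = sync_spike_sorting(ap_file, out_path)
    return out_files, error_code
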

def ks2_to_alf(ks_path, bin_path, out_path, bin_file=None, ampfactor=1, label=None, force=True):
    """
    Convert Kilosort 2 output to an ALF dataset for single probe data
    :param ks_path: path to the Kilosort 2 output folder
    :param bin_path: path of the raw data folder
    :param out_path: destination path for the ALF dataset
    :return:
    """
    m = ephysqc.phy_model_from_ks2_path(ks2_path=ks_path, bin_path=bin_path, bin_file=bin_file)
    ac = phylib.io.alf.EphysAlfCreator(m)
    ac.convert(out_path, label=label, force=force, ampfactor=float(ampfactor))

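
def _example_ks2_to_alf(session_path):
    # Usage sketch (illustrative only, not part of the original module): convert the
    # Kilosort 2 output of one probe into ALF files. Folder names are hypothetical.
    session_path = Path(session_path)
    ks_path = session_path.joinpath('spike_sorters', 'ks2_matlab', 'probe00')
    bin_path = session_path.joinpath('raw_ephys_data', 'probe00')
    out_path = session_path.joinpath('alf', 'probe00')
    ks2_to_alf(ks_path, bin_path, out_path, ampfactor=1, label='probe00', force=True)
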

def ks2_to_tar(ks_path, out_path, force=False):
    """
    Compress the output from Kilosort 2 into a tar file so that it can be registered to FlatIron
    and moved to the spikesorters/ks2_matlab/probexx path. Outputs the file to register.

    :param ks_path: path to the Kilosort 2 output
    :param out_path: path in which the tar file is created
    :return:
        path to the tarred Kilosort output

    To extract files from the tar file, use the following code
    Example:
        save_path = Path('folder you want to extract to')
        with tarfile.open('_kilosort_output.tar', 'r') as tar_dir:
            tar_dir.extractall(path=save_path)

    """
    ks2_output = ['amplitudes.npy',
                  'channel_map.npy',
                  'channel_positions.npy',
                  'cluster_Amplitude.tsv',
                  'cluster_ContamPct.tsv',
                  'cluster_group.tsv',
                  'cluster_KSLabel.tsv',
                  'params.py',
                  'pc_feature_ind.npy',
                  'pc_features.npy',
                  'similar_templates.npy',
                  'spike_clusters.npy',
                  'spike_sorting_ks2.log',
                  'spike_templates.npy',
                  'spike_times.npy',
                  'template_feature_ind.npy',
                  'template_features.npy',
                  'templates.npy',
                  'templates_ind.npy',
                  'whitening_mat.npy',
                  'whitening_mat_inv.npy']

    out_file = Path(out_path).joinpath('_kilosort_raw.output.tar')
    if out_file.exists() and not force:
        _logger.info(f"Already converted ks2 to tar: for {ks_path}, skipping.")
        return [out_file]

    with tarfile.open(out_file, 'w') as tar_dir:
        for file in Path(ks_path).iterdir():
            if file.name in ks2_output:
                tar_dir.add(file, file.name)

    return [out_file]

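
def _example_ks2_to_tar(session_path):
    # Usage sketch (illustrative only, not part of the original module): archive the raw
    # Kilosort 2 output of one probe in place. Folder names are hypothetical.
    ks_path = Path(session_path).joinpath('spike_sorters', 'ks2_matlab', 'probe00')
    return ks2_to_tar(ks_path, out_path=ks_path, force=False)
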

def detection(data, fs, h, detect_threshold=-4, time_tol=.002, distance_threshold_um=70):
    """
    Detects and de-duplicates negative voltage spikes based on voltage thresholding.
    The de-duplication step keeps the maximum-amplitude events. To account for collisions the
    amplitude is assumed to decay from the peak. If this is a multipeak event, each peak is
    labeled as a spike.

    :param data: 2D numpy array, nsamples x nchannels
    :param fs: sampling frequency (Hz)
    :param h: dictionary with the neuropixel geometry header: see neuropixel.trace_header
    :param detect_threshold: negative value below which the voltage is considered to be a spike
    :param time_tol: time in seconds for which samples before and after are assumed to be part of the spike
    :param distance_threshold_um: distance within which threshold crossings are assumed to be part of the same spike
    :return: spikes dictionary of vectors with keys "time", "trace", "amp" and "ispike"
    """
    multipeak = False
    time_bracket = np.array([-1, 1]) * time_tol
    inds, indtr = np.where(data < detect_threshold)
    picks = Bunch(time=inds / fs, trace=indtr, amp=data[inds, indtr], ispike=np.zeros(inds.size))
    amp_order = np.argsort(picks.amp)

    hxy = h['x'] + 1j * h['y']

    spike_id = 1
    while np.any(picks.ispike == 0):
        # find the first unassigned spike with the highest amplitude
        iamp = np.where(picks.ispike[amp_order] == 0)[0][0]
        imax = amp_order[iamp]
        # look only within the time range
        itlims = np.searchsorted(picks.time, picks.time[imax] + time_bracket)
        itlims = np.arange(itlims[0], itlims[1])

        offset = np.abs(hxy[picks.trace[itlims]] - hxy[picks.trace[imax]])
        iit = np.where(offset < distance_threshold_um)[0]

        picks.ispike[itlims[iit]] = -1
        picks.ispike[imax] = spike_id
        # handles collisions with a simple amplitude decay model: if the amplitude doesn't decay
        # as a function of offset, then it is a collision and another spike is set
        if multipeak:  # noqa
            iii = np.lexsort((picks.amp[itlims[iit]], offset[iit]))
            sorted_amps_db = 20 * np.log10(np.abs(picks.amp[itlims[iit][iii]]))
            idetect = np.r_[0, np.where(np.diff(sorted_amps_db) > 12)[0] + 1]
            picks.ispike[itlims[iit[iii[idetect]]]] = np.arange(idetect.size) + spike_id
            spike_id += idetect.size
        else:
            spike_id += 1

    detects = Bunch({k: picks[k][picks.ispike > 0] for k in picks})
    return detects
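
def _example_detection():
    # Usage sketch (illustrative only, not part of the original module): run the threshold
    # detection on a small synthetic snippet. The geometry header comes from
    # neuropixel.trace_header (referenced in the docstring); amplitudes are arbitrary units.
    import neuropixel
    h = neuropixel.trace_header()
    fs = 30000
    data = np.zeros((3000, h['x'].size))  # nsamples x nchannels
    data[1500, 100] = -8.  # a single negative deflection below the default threshold of -4
    return detection(data, fs, h)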