Coverage for brainbox/io/spikeglx.py: 74% (124 statements)

import shutil
import logging
from pathlib import Path
import time
import json
import string
import random

import numpy as np
from one.alf.files import remove_uuid_string

import spikeglx

_logger = logging.getLogger('ibllib')


def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    Extracts spike waveforms from a binary ephys data file, after (optionally)
    common-average-referencing (CAR) of spatial noise.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : ndarray_like
        The timestamps (in s) of the spikes.
    ch : ndarray_like
        The channels on which to extract the waveforms.
    t : numeric (optional)
        The duration (in ms) of each returned waveform.
    sr : int (optional)
        The sampling rate (in Hz) at which the ephys data was acquired.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car : bool (optional)
        A flag to perform CAR before extracting waveforms (currently not
        implemented; see below).

    Returns
    -------
    waveforms : ndarray
        An array of shape (#spikes, #samples, #channels) containing the waveforms.

    Examples
    --------
    1) Extract all the waveforms for unit 1 with and without CAR.
    >>> import numpy as np
    >>> import brainbox as bb
    >>> import one.alf.io as alfio
    >>> import ibllib.ephys.spikes as e_spks
    (Note: if there is no 'alf' directory, create one from the 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get a clusters bunch and a units bunch from a spikes bunch from an alf directory.
    >>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times'])
    # Get the timestamps and 20 channels around the max amp channel for unit 1, and extract the
    # two sets of waveforms.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels above `max_ch`.
    ...     ch = np.arange(max_ch, max_ch + 20)
    ... elif (max_ch + 10) > 385:  # take only channels below `max_ch`.
    ...     ch = np.arange(max_ch - 20, max_ch)
    ... else:  # take 10 channels on each side of `max_ch`.
    ...     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=False)
    >>> wf_car = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=True)
    """

    # Get memmapped array of `ephys_file`
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array
        n_wf_samples = int(sr / 1000 * (t / 2))  # number of samples to return on each side of a ts
        ts_samples = np.array(ts * sr).astype(int)  # the samples corresponding to `ts`
        t_sample_first = ts_samples[0] - n_wf_samples

        # Exception handling for impossible channels
        ch = np.asarray(ch)
        ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
        if np.any(ch < 0) or np.any(ch >= n_ch_probe):
            raise Exception('At least one specified channel number is impossible. '
                            f'The minimum channel number was {np.min(ch)}, '
                            f'and the maximum channel number was {np.max(ch)}. '
                            'Check specified channel numbers and try again.')

        if car:  # compute spatial noise in chunks
            # see https://github.com/int-brain-lab/iblenv/issues/5
            raise NotImplementedError("CAR option is not available")

        # Initialize `waveforms` and extract waveforms from `file_m`.
        waveforms = np.zeros((len(ts), 2 * n_wf_samples, ch.size))
        # Give a time estimate for extracting waveforms by timing the first five extractions.
        t0 = time.perf_counter()
        for i in range(5):
            waveforms[i, :, :] = \
                file_m[i * n_wf_samples * 2 + t_sample_first:
                       i * n_wf_samples * 2 + t_sample_first + n_wf_samples * 2, ch].reshape(
                    (n_wf_samples * 2, ch.size))
        dt = time.perf_counter() - t0
        print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
              .format(dt * len(ts) / 60 / 5, time.ctime()))
        for spk, _ in enumerate(ts):  # extract waveforms
            spk_ts_sample = ts_samples[spk]
            spk_samples = np.arange(spk_ts_sample - n_wf_samples, spk_ts_sample + n_wf_samples)
            # have to reshape to add an axis to broadcast `file_m` into `waveforms`
            waveforms[spk, :, :] = \
                file_m[spk_samples[0]:spk_samples[-1] + 1, ch].reshape((spk_samples.size, ch.size))
        print('Done. ({})'.format(time.ctime()))

    return waveforms
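
# A minimal sanity-check sketch of the window arithmetic above (illustrative
# only, not part of the original module): with the default t=2.0 ms at
# sr=30000 Hz, each waveform spans 30 samples on either side of the spike
# time, 60 samples in total. The spike times below are hypothetical values.
def _waveform_window_demo(t=2.0, sr=30000):
    ts = np.array([0.1, 0.25])  # hypothetical spike times (s)
    n_wf_samples = int(sr / 1000 * (t / 2))  # samples on each side of a spike
    ts_samples = np.array(ts * sr).astype(int)  # spike times as sample indices
    # each row is the (first, last) sample index of one extracted window
    return np.c_[ts_samples - n_wf_samples, ts_samples + n_wf_samples]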


class Streamer(spikeglx.Reader):
    """
    Streams chunks of compressed raw ephys data from the remote data server while
    exposing the usual `spikeglx.Reader` slicing interface, e.g.:

    pid = 'e31b4e39-e350-47a9-aca4-72496d99ff2a'
    one = ONE()
    sr = Streamer(pid=pid, one=one)
    raw_voltage = sr[int(t0 * sr.fs):int((t0 + nsecs) * sr.fs), :]
    """

    def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
        self.target_dir = None  # last chunk directory downloaded or read
        self.one = one
        self.pid = pid
        self.cache_folder = cache_folder or Path(self.one.alyx._par.CACHE_DIR).joinpath('cache', typ)
        self.remove_cached = remove_cached
        self.eid, self.pname = self.one.pid2eid(pid)
        self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
        meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
        cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
        self.url_cbin = self.one.record2url(cbin_rec)[0]
        with open(self.file_chunks, 'r') as f:
            self.chunks = json.load(f)
        self.chunks['chunk_bounds'] = np.array(self.chunks['chunk_bounds'])
        super(Streamer, self).__init__(meta_file, ignore_warnings=True)
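    # Sketch of the mtscomp .ch metadata loaded above (hypothetical values, not
    # part of the original source): `chunk_bounds` are per-chunk sample
    # boundaries and `chunk_offsets` the matching byte offsets into the .cbin,
    # e.g. {"chunk_bounds": [0, 30000, 60000, ...], "chunk_offsets": [0, 1200, 2400, ...]}
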

    def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True, volts=True):
        """
        Overloads the base-class read by first downloading only the chunks of the
        compressed file needed to cover the requested sample range.
        """
        first_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.start) - 1)
        last_chunk = np.maximum(0, np.searchsorted(self.chunks['chunk_bounds'], nsel.stop) - 1)
        n0 = self.chunks['chunk_bounds'][first_chunk]
        _logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
        self.cache_folder.mkdir(exist_ok=True, parents=True)
        sr, file_cbin = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
        if not volts:
            data = np.copy(sr._raw[nsel.start - n0:nsel.stop - n0, csel])
        else:
            data = sr[nsel.start - n0:nsel.stop - n0, csel]
        sr.close()
        if self.remove_cached:
            shutil.rmtree(self.target_dir)
        return data

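    # Chunk lookup sketch (hypothetical numbers, not part of the original
    # source): with chunk_bounds = [0, 30000, 60000, 90000], a request for
    # nsel = slice(45000, 75000) gives
    #     first_chunk = np.searchsorted(chunk_bounds, 45000) - 1 = 1
    #     last_chunk = np.searchsorted(chunk_bounds, 75000) - 1 = 2
    # so chunks 1 and 2 are fetched and the slice is re-based to n0 = 30000.
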

    def _download_raw_partial(self, first_chunk=0, last_chunk=0):
        """
        Downloads one or several chunks of an mtscomp file, copies the .ch and .meta
        files alongside, and returns a spikeglx.Reader on the cached data.

        :param first_chunk: index of the first chunk to download
        :param last_chunk: index of the last chunk to download
        :return: spikeglx.Reader on the cached chunk(s), Path of the local .cbin file
        """

        assert str(self.url_cbin).endswith('.cbin')
        webclient = self.one.alyx
        relpath = Path(self.url_cbin.replace(webclient._par.HTTP_DATA_SERVER, '.')).parents[0]
        # write the temp file into a subdirectory
        tdir_chunk = f"chunk_{str(first_chunk).zfill(6)}_to_{str(last_chunk).zfill(6)}"
        # for parallel processes there is a risk of name collisions when `remove_cached` is True:
        # if the folder is to be removed, append a unique identifier to avoid duplicate names
        if self.remove_cached:
            tdir_chunk += ''.join([random.choice(string.ascii_letters) for _ in np.arange(10)])
        self.target_dir = Path(self.cache_folder, relpath, tdir_chunk)
        Path(self.target_dir).mkdir(parents=True, exist_ok=True)
        ch_file_stream = self.target_dir.joinpath(self.file_chunks.name).with_suffix('.stream.ch')

        # Get the first sample index, and the number of samples to download.
        i0 = self.chunks['chunk_bounds'][first_chunk]
        ns_stream = self.chunks['chunk_bounds'][last_chunk + 1] - i0
        total_samples = self.chunks['chunk_bounds'][-1]

        # handles the meta file
        meta_local_path = ch_file_stream.with_suffix('.meta')
        if not meta_local_path.exists():
            shutil.copy(self.file_chunks.with_suffix('.meta'), meta_local_path)

        # if the cached version happens to be the same as the one on disk, just load it
        if ch_file_stream.exists() and ch_file_stream.with_suffix('.cbin').exists():
            with open(ch_file_stream, 'r') as f:
                cmeta_stream = json.load(f)
            if (cmeta_stream.get('chopped_first_sample', None) == i0 and
                    cmeta_stream.get('chopped_total_samples', None) == total_samples):
                return spikeglx.Reader(ch_file_stream.with_suffix('.cbin'), ignore_warnings=True), ch_file_stream
        else:
            shutil.copy(self.file_chunks, ch_file_stream)
        assert ch_file_stream.exists()
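        # Re-basing sketch (hypothetical numbers, not part of the original
        # source): downloading chunks 1-2 of chunk_bounds = [0, 30000, 60000, 90000]
        # and chunk_offsets = [0, 1200, 2400, 3600] keeps the slices
        # [30000, 60000, 90000] and [1200, 2400, 3600], which re-base to
        # [0, 30000, 60000] and [0, 1200, 2400] so the chopped .cbin can be
        # read as a standalone file.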

        cmeta = self.chunks.copy()
        # prepare the metadata file
        cmeta['chunk_bounds'] = cmeta['chunk_bounds'][first_chunk:last_chunk + 2]
        cmeta['chunk_bounds'] = [int(_ - i0) for _ in cmeta['chunk_bounds']]
        assert len(cmeta['chunk_bounds']) >= 2
        assert cmeta['chunk_bounds'][0] == 0

        first_byte = cmeta['chunk_offsets'][first_chunk]
        cmeta['chunk_offsets'] = cmeta['chunk_offsets'][first_chunk:last_chunk + 2]
        cmeta['chunk_offsets'] = [_ - first_byte for _ in cmeta['chunk_offsets']]
        assert len(cmeta['chunk_offsets']) >= 2
        assert cmeta['chunk_offsets'][0] == 0
        n_bytes = cmeta['chunk_offsets'][-1]
        assert n_bytes > 0

        # Save the chopped chunk bounds and offsets.
        cmeta['sha1_compressed'] = None
        cmeta['sha1_uncompressed'] = None
        cmeta['chopped'] = True
        cmeta['chopped_first_sample'] = int(i0)
        cmeta['chopped_samples'] = int(ns_stream)
        cmeta['chopped_total_samples'] = int(total_samples)

        with open(ch_file_stream, 'w') as f:
            json.dump(cmeta, f, indent=2, sort_keys=True)

        # Download the requested chunks, with up to 5 retries on failure.
        retries = 0
        while True:
            try:
                cbin_local_path = webclient.download_file(
                    self.url_cbin, chunks=(first_byte, n_bytes),
                    target_dir=self.target_dir, clobber=True, return_md5=False)
                break
            except Exception as e:
                retries += 1
                if retries > 5:
                    raise e
                _logger.warning(f'Failed to download chunk {first_chunk} to {last_chunk}, retrying')
                time.sleep(1)
        cbin_local_path_renamed = remove_uuid_string(cbin_local_path).with_suffix('.stream.cbin')
        cbin_local_path.replace(cbin_local_path_renamed)
        assert cbin_local_path_renamed.exists()

        reader = spikeglx.Reader(cbin_local_path_renamed, ignore_warnings=True)
        return reader, cbin_local_path_renamed
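

# End-to-end usage sketch (illustrative only, not part of the original module):
# `pid`, `t0` and `nsecs` are hypothetical values, and ONE credentials are
# assumed to be configured.
def _streamer_usage_example():
    from one.api import ONE
    one = ONE()
    pid = 'e31b4e39-e350-47a9-aca4-72496d99ff2a'  # example probe insertion id
    sr = Streamer(pid=pid, one=one, remove_cached=True)
    t0, nsecs = 100.0, 1.0  # read one second of data starting at t = 100 s
    # slicing triggers the download of just the chunks covering this window
    raw_voltage = sr[int(t0 * sr.fs):int((t0 + nsecs) * sr.fs), :]
    return raw_voltage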