Coverage for brainbox/io/spikeglx.py: 74%
125 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1import shutil
2import logging
3from pathlib import Path
4import time
5import json
6import string
7import random
9import numpy as np
10from one.alf.path import remove_uuid_string
12import spikeglx
# Module logger, shared with the rest of the 'ibllib' logging hierarchy.
_logger = logging.getLogger(name='ibllib')
def extract_waveforms(ephys_file, ts, ch, t=2.0, sr=30000, n_ch_probe=385, car=True):
    """
    Extracts spike waveforms from binary ephys data file, after (optionally)
    common-average-referencing (CAR) spatial noise.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    ts : array_like
        The timestamps (in s) of the spikes. A scalar is treated as a single spike.
    ch : array_like
        The (0-based) channels on which to extract the waveforms. Every channel must
        satisfy ``0 <= ch < n_ch_probe``.
    t : numeric (optional)
        The time (in ms) of each returned waveform, centred on each timestamp.
    sr : int (optional)
        The sampling rate (in hz) that the ephys data was acquired at.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    car: bool (optional)
        A flag to perform CAR before extracting waveforms. CAR is currently not
        implemented: ``car=True`` (the default) raises NotImplementedError, so pass
        ``car=False``.

    Returns
    -------
    waveforms : ndarray
        An array of shape (#spikes, #samples, #channels) containing the waveforms, where
        #samples is ``2 * int(sr / 1000 * (t / 2))``.

    Raises
    ------
    ValueError
        If a channel is outside ``[0, n_ch_probe)``, or if a waveform window would extend
        past the start or the end of the recording.
    NotImplementedError
        If ``car`` is True.

    Examples
    --------
    1) Extract all the waveforms for unit1 without CAR.
    >>> import numpy as np
    >>> import brainbox as bb
    >>> import one.alf.io as alfio
    >>> import ibllib.ephys.spikes as e_spks
    (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
    >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
    # Get a clusters bunch and a units bunch from a spikes bunch from an alf directory.
    >>> clstrs_b = alfio.load_object(path_to_alf_out, 'clusters')
    >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
    >>> units_b = bb.processing.get_units_bunch(spks_b, ['times'])
    # Get the timestamps and 20 channels around the max amp channel for unit1, and extract
    # the waveforms.
    >>> ts = units_b['times']['1']
    >>> max_ch = clstrs_b['channels'][1]
    >>> if max_ch < 10:  # take only channels greater than `max_ch`.
    ...     ch = np.arange(max_ch, max_ch + 20)
    ... elif (max_ch + 10) > 385:  # take only channels less than `max_ch`.
    ...     ch = np.arange(max_ch - 20, max_ch)
    ... else:  # take 20 channels around `max_ch`.
    ...     ch = np.arange(max_ch - 10, max_ch + 10)
    >>> wf = bb.io.extract_waveforms(path_to_ephys_file, ts, ch, car=False)
    """
    # Get memmapped array of `ephys_file`
    with spikeglx.Reader(ephys_file) as s_reader:
        file_m = s_reader.data  # the memmapped array, indexed as [samples, channels]
        n_wf_samples = int(sr / 1000 * (t / 2))  # number of samples to return on each side of a ts
        n_samples = 2 * n_wf_samples  # total number of samples per waveform
        # Convert to an array *before* scaling: `list * sr` would repeat the list, not scale it.
        ts_samples = (np.atleast_1d(np.asarray(ts)) * sr).astype(int)  # the samples of `ts`
        n_spikes = ts_samples.size

        # Exception handling for impossible channels (valid 0-based indices: 0 .. n_ch_probe - 1)
        ch = np.asarray(ch)
        ch = ch.reshape((ch.size, 1)) if ch.size == 1 else ch
        if np.any(ch < 0) or np.any(ch >= n_ch_probe):
            raise ValueError('At least one specified channel number is impossible. '
                             f'The minimum channel number was {np.min(ch)}, '
                             f'and the maximum channel number was {np.max(ch)}. '
                             'Check specified channel numbers and try again.')

        if car:  # compute spatial noise in chunks
            # see https://github.com/int-brain-lab/iblenv/issues/5
            raise NotImplementedError("CAR option is not available")

        # Windows crossing the file edges would otherwise wrap (negative start) or be truncated,
        # and fail later with an opaque reshape error.
        starts = ts_samples - n_wf_samples
        if np.any(starts < 0) or np.any(starts + n_samples > file_m.shape[0]):
            raise ValueError('At least one waveform window extends beyond the recording. '
                             f'Samples must lie in [{n_wf_samples}, '
                             f'{file_m.shape[0] - n_wf_samples}], '
                             f'got [{np.min(ts_samples)}, {np.max(ts_samples)}].')

        waveforms = np.zeros((n_spikes, n_samples, ch.size))
        # Time the first few real extractions (at most 5, fewer if there are fewer spikes)
        # to give a time estimate, instead of extracting throw-away windows.
        n_est = min(5, n_spikes)
        t0 = time.perf_counter()
        for spk, s0 in enumerate(starts):
            # reshape collapses the extra axis a single (1, 1) channel index produces
            waveforms[spk, :, :] = file_m[s0:s0 + n_samples, ch].reshape((n_samples, ch.size))
            if spk + 1 == n_est:
                dt = time.perf_counter() - t0
                print('Performing waveform extraction. Estimated time is {:.2f} mins. ({})'
                      .format(dt * n_spikes / 60 / n_est, time.ctime()))
        print('Done. ({})'.format(time.ctime()))

    return waveforms
class Streamer(spikeglx.Reader):
    """
    A spikeglx.Reader that streams a chunked (mtscomp) recording from the remote data server.

    Only the compressed chunks covering the requested samples are downloaded into a local
    cache folder, then read through a regular `spikeglx.Reader`.

    Example:
        pid = 'e31b4e39-e350-47a9-aca4-72496d99ff2a'
        one = ONE()
        sr = Streamer(pid=pid, one=one)
        raw_voltage = sr[int(t0 * sr.fs):int((t0 + nsecs) * sr.fs), :]
    """
    def __init__(self, pid, one, typ='ap', cache_folder=None, remove_cached=False):
        """
        :param pid: probe insertion id, resolved to (eid, probe name) via `one.pid2eid`
        :param one: ONE instance used to list, load and download datasets
        :param typ: file type used in the dataset filename patterns (e.g. 'ap' -> '*.ap.ch')
        :param cache_folder: folder where chunks are cached (str or Path); defaults to
         <ONE cache dir>/cache/<typ>
        :param remove_cached: if True, delete each downloaded chunk folder after reading it
        """
        self.target_dir = None  # last chunk directory downloaded or read
        self.one = one
        self.pid = pid
        # Coerce to Path so that a str cache folder still supports `.mkdir` in `read`.
        self.cache_folder = (Path(cache_folder) if cache_folder
                             else Path(self.one.alyx._par.CACHE_DIR).joinpath('cache', typ))
        self.remove_cached = remove_cached
        self.eid, self.pname = self.one.pid2eid(pid)
        self.file_chunks = self.one.load_dataset(self.eid, f'*.{typ}.ch', collection=f"*{self.pname}")
        meta_file = self.one.load_dataset(self.eid, f'*.{typ}.meta', collection=f"*{self.pname}")
        cbin_rec = self.one.list_datasets(self.eid, collection=f"*{self.pname}", filename=f'*{typ}.*bin', details=True)
        cbin_rec.index = cbin_rec.index.map(lambda x: (self.eid, x))
        self.url_cbin = self.one.record2url(cbin_rec)[0]
        with open(self.file_chunks, 'r') as f:
            self.chunks = json.load(f)
            self.chunks['chunk_bounds'] = np.array(self.chunks['chunk_bounds'])
        super().__init__(meta_file, ignore_warnings=True)

    def read(self, nsel=slice(0, 10000), csel=slice(None), sync=True, volts=True):
        """
        Overload the read function by downloading only the chunks that contain `nsel`.

        :param nsel: slice of samples to read; a None start/stop means the start/end of the file
        :param csel: channel selection, applied after the sample selection
        :param sync: not used by this implementation
        :param volts: if True, index the chunk reader (scaled data); if False return a copy
         of its raw array
        :return: the selected data
        """
        bounds = self.chunks['chunk_bounds']  # first sample of each chunk, then total samples
        start = 0 if nsel.start is None else nsel.start
        stop = bounds[-1] if nsel.stop is None else nsel.stop
        last_index = len(bounds) - 2  # index of the last chunk
        # side='right' so that a start lying exactly on a chunk bound selects that chunk rather
        # than also downloading the previous one.
        first_chunk = int(np.clip(np.searchsorted(bounds, start, side='right') - 1, 0, last_index))
        last_chunk = int(np.clip(np.searchsorted(bounds, stop) - 1, first_chunk, last_index))
        n0 = bounds[first_chunk]
        _logger.debug(f'Streamer: caching sample {n0}, (t={n0 / self.fs})')
        self.cache_folder.mkdir(exist_ok=True, parents=True)
        sr, _ = self._download_raw_partial(first_chunk=first_chunk, last_chunk=last_chunk)
        try:
            if volts:
                data = sr[start - n0:stop - n0, csel]
            else:
                data = np.copy(sr._raw[start - n0:stop - n0, csel])
        finally:
            sr.close()  # release the chunk reader even if indexing fails
        if self.remove_cached:
            shutil.rmtree(self.target_dir)
        return data

    def _download_raw_partial(self, first_chunk=0, last_chunk=0):
        """
        Download one or several chunks of a mtscomp file, write the matching chopped `.ch` and
        `.meta` files next to it, and return a spikeglx.Reader on the local copy.

        :param first_chunk: index of the first chunk to download (inclusive)
        :param last_chunk: index of the last chunk to download (inclusive)
        :return: spikeglx.Reader of the downloaded chunk(s)
        :return: pathlib.Path of the local `.stream.cbin` file
        """
        assert str(self.url_cbin).endswith('.cbin')
        webclient = self.one.alyx
        relpath = Path(self.url_cbin.replace(webclient._par.HTTP_DATA_SERVER, '.')).parents[0]
        # write the temp file into a subdirectory
        tdir_chunk = f"chunk_{str(first_chunk).zfill(6)}_to_{str(last_chunk).zfill(6)}"
        # for parallel processes, there is a risk of collisions if the remove_cached flag is set
        # to True: if the folder is to be removed, append a random suffix to avoid duplicates
        if self.remove_cached:
            tdir_chunk += ''.join(random.choice(string.ascii_letters) for _ in range(10))
        self.target_dir = Path(self.cache_folder, relpath, tdir_chunk)
        self.target_dir.mkdir(parents=True, exist_ok=True)
        ch_file_stream = self.target_dir.joinpath(self.file_chunks.name).with_suffix('.stream.ch')
        cbin_file_stream = ch_file_stream.with_suffix('.cbin')

        # Get the first sample index, and the number of samples to download.
        bounds = self.chunks['chunk_bounds']
        i0 = bounds[first_chunk]
        ns_stream = bounds[last_chunk + 1] - i0
        total_samples = bounds[-1]

        # handles the meta file
        meta_local_path = ch_file_stream.with_suffix('.meta')
        if not meta_local_path.exists():
            shutil.copy(self.file_chunks.with_suffix('.meta'), meta_local_path)

        # if the cached version happens to be the same as the one on disk, just load it
        if ch_file_stream.exists() and cbin_file_stream.exists():
            with open(ch_file_stream, 'r') as f:
                cmeta_stream = json.load(f)
            if (cmeta_stream.get('chopped_first_sample', None) == i0 and
                    cmeta_stream.get('chopped_total_samples', None) == total_samples):
                return spikeglx.Reader(cbin_file_stream, ignore_warnings=True), cbin_file_stream
        else:
            shutil.copy(self.file_chunks, ch_file_stream)
        assert ch_file_stream.exists()

        # prepare the metadata file: chunk bounds and byte offsets re-based on the first chunk
        cmeta = self.chunks.copy()
        cmeta['chunk_bounds'] = [int(b - i0) for b in bounds[first_chunk:last_chunk + 2]]
        assert len(cmeta['chunk_bounds']) >= 2
        assert cmeta['chunk_bounds'][0] == 0

        offsets = cmeta['chunk_offsets']
        first_byte = offsets[first_chunk]
        cmeta['chunk_offsets'] = [o - first_byte for o in offsets[first_chunk:last_chunk + 2]]
        assert len(cmeta['chunk_offsets']) >= 2
        assert cmeta['chunk_offsets'][0] == 0
        n_bytes = cmeta['chunk_offsets'][-1]
        assert n_bytes > 0

        # Save the chopped chunk bounds and offsets.
        cmeta['sha1_compressed'] = None
        cmeta['sha1_uncompressed'] = None
        cmeta['chopped'] = True
        cmeta['chopped_first_sample'] = int(i0)
        cmeta['chopped_samples'] = int(ns_stream)
        cmeta['chopped_total_samples'] = int(total_samples)

        with open(ch_file_stream, 'w') as f:
            json.dump(cmeta, f, indent=2, sort_keys=True)

        # Download the requested chunks, retrying up to 5 times on failure
        retries = 0
        while True:
            try:
                cbin_local_path = webclient.download_file(
                    self.url_cbin, chunks=(first_byte, n_bytes),
                    target_dir=self.target_dir, clobber=True, return_md5=False)
                break
            except Exception:
                retries += 1
                if retries > 5:
                    raise
                _logger.warning(f'Failed to download chunk {first_chunk} to {last_chunk}, retrying')
                time.sleep(1)
        cbin_local_path_renamed = remove_uuid_string(cbin_local_path).with_suffix('.stream.cbin')
        cbin_local_path.replace(cbin_local_path_renamed)
        assert cbin_local_path_renamed.exists()

        reader = spikeglx.Reader(cbin_local_path_renamed, ignore_warnings=True)
        # return the renamed path: the original download path no longer exists after `replace`
        return reader, cbin_local_path_renamed