Coverage for ibllib/pipes/ephys_tasks.py: 57%

468 statements  


import logging
from pathlib import Path
import re
import shutil
import subprocess
import traceback

import packaging.version
import numpy as np
import pandas as pd
import spikeglx
import neuropixel
from ibldsp.utils import rms
from ibldsp.waveform_extraction import extract_wfs_cbin
import one.alf.io as alfio
import iblutil.util

from ibllib.misc import check_nvidia_driver
from ibllib.pipes import base_tasks
from ibllib.pipes.sync_tasks import SyncPulses
from ibllib.ephys import ephysqc
import ibllib.ephys.spikes
from ibllib.qc.alignment_qc import get_aligned_channels
from ibllib.plots.figures import LfpPlots, ApPlots, BadChannelsAp
from ibllib.plots.figures import SpikeSorting as SpikeSortingPlots
from ibllib.io.extractors.ephys_fpga import extract_sync
from ibllib.ephys.spikes import sync_probes


_logger = logging.getLogger("ibllib")


class EphysRegisterRaw(base_tasks.DynamicTask):
    """
    Creates the probe insertions and uploads the probe description file; also compresses the nidq
    files and uploads them.
    """

    priority = 100
    job_size = 'small'

    @property
    def signature(self):
        signature = {
            'input_files': [],
            'output_files': [('probes.description.json', 'alf', True)]
        }
        return signature

    def _run(self):
        out_files = ibllib.ephys.spikes.probes_description(self.session_path, self.one)
        return out_files

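# Illustrative invocation of the task above (a sketch, not pipeline code: the session path is
# hypothetical and a configured ONE instance is assumed to be available as `one`):
#
#   task = EphysRegisterRaw(Path('/data/subject/2025-01-01/001'), one=one)
#   task.run()  # creates the insertions on Alyx and writes alf/probes.description.json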

class EphysSyncRegisterRaw(base_tasks.DynamicTask):
    """
    Task to rename, compress and register raw DAQ data in .bin format collected using a NIDAQ.
    """

    priority = 90
    cpu = 2
    io_charge = 30  # this job reads raw ap files
    job_size = 'small'

    @property
    def signature(self):
        signature = {
            'input_files': [('*.*bin', self.sync_collection, True),
                            ('*.meta', self.sync_collection, True),
                            ('*.wiring.json', self.sync_collection, True)],
            'output_files': [('*nidq.cbin', self.sync_collection, True),
                             ('*nidq.ch', self.sync_collection, True),
                             ('*nidq.meta', self.sync_collection, True),
                             ('*nidq.wiring.json', self.sync_collection, True)]
        }
        return signature

    def _run(self):
        out_files = []

        # Detect the wiring file
        wiring_file = next(self.session_path.joinpath(self.sync_collection).glob('*.wiring.json'), None)
        if wiring_file is not None:
            out_files.append(wiring_file)

        # Search for .bin files in the sync_collection folder
        files = list(self.session_path.joinpath(self.sync_collection).glob('*nidq.*bin'))
        bin_file = files[0] if len(files) == 1 else None

        # If we no longer have a .bin / .cbin file, see if we can still find the .ch and .meta files
        if bin_file is None:
            for ext in ['ch', 'meta']:
                files = list(self.session_path.joinpath(self.sync_collection).glob(f'*nidq.{ext}'))
                ext_file = files[0] if len(files) == 1 else None
                if ext_file is not None:
                    out_files.append(ext_file)

            return out_files if len(out_files) > 0 else None

        # If we do find the .bin file, compress it (only if it hasn't already been compressed)
        sr = spikeglx.Reader(bin_file)
        if sr.is_mtscomp:
            sr.close()
            cbin_file = bin_file
            assert cbin_file.suffix == '.cbin'
        else:
            cbin_file = sr.compress_file()
            sr.close()
            bin_file.unlink()

        meta_file = cbin_file.with_suffix('.meta')
        ch_file = cbin_file.with_suffix('.ch')

        out_files.append(cbin_file)
        out_files.append(ch_file)
        out_files.append(meta_file)

        return out_files

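# A minimal standalone sketch of the compression step used by the task above, assuming a raw
# nidq .bin file at a hypothetical path (mtscomp compression through spikeglx):
#
#   sr = spikeglx.Reader(Path('/data/.../raw_ephys_data/_spikeglx_ephysData_g0_t0.nidq.bin'))
#   if not sr.is_mtscomp:
#       cbin_file = sr.compress_file()  # writes the .cbin and .ch files next to the .bin
#   sr.close()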

class EphysCompressNP1(base_tasks.EphysTask):
    priority = 90
    cpu = 2
    io_charge = 100  # this job reads raw ap files
    job_size = 'small'

    @property
    def signature(self):
        signature = {
            'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*ap.*bin', f'{self.device_collection}/{self.pname}', True),
                            ('*lf.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*lf.*bin', f'{self.device_collection}/{self.pname}', True),
                            ('*wiring.json', f'{self.device_collection}/{self.pname}', False)],
            'output_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                             ('*ap.cbin', f'{self.device_collection}/{self.pname}', True),
                             ('*ap.ch', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.meta', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.cbin', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.ch', f'{self.device_collection}/{self.pname}', True),
                             ('*wiring.json', f'{self.device_collection}/{self.pname}', False)]
        }
        return signature

    def _run(self):
        out_files = []

        # Detect and upload the wiring file
        wiring_file = next(self.session_path.joinpath(self.device_collection, self.pname).glob('*wiring.json'), None)
        if wiring_file is not None:
            out_files.append(wiring_file)

        ephys_files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname))
        ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="ch")
        ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="meta")

        for ef in ephys_files:
            for typ in ["ap", "lf"]:
                bin_file = ef.get(typ)
                if not bin_file:
                    continue
                if bin_file.suffix.find("bin") == 1:
                    with spikeglx.Reader(bin_file) as sr:
                        if sr.is_mtscomp:
                            out_files.append(bin_file)
                        else:
                            _logger.info(f"Compressing binary file {bin_file}")
                            cbin_file = sr.compress_file()
                            sr.close()
                            bin_file.unlink()
                            out_files.append(cbin_file)
                            out_files.append(cbin_file.with_suffix('.ch'))
                else:
                    out_files.append(bin_file)

        return out_files


class EphysCompressNP21(base_tasks.EphysTask):
    priority = 90
    cpu = 2
    io_charge = 100  # this job reads raw ap files
    job_size = 'large'

    @property
    def signature(self):
        signature = {
            'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*ap.*bin', f'{self.device_collection}/{self.pname}', True),
                            ('*wiring.json', f'{self.device_collection}/{self.pname}', False)],
            'output_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                             ('*ap.cbin', f'{self.device_collection}/{self.pname}', True),
                             ('*ap.ch', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.meta', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.cbin', f'{self.device_collection}/{self.pname}', True),
                             ('*lf.ch', f'{self.device_collection}/{self.pname}', True),
                             ('*wiring.json', f'{self.device_collection}/{self.pname}', False)]
        }
        return signature

    def _run(self):
        out_files = []
        # Detect wiring files
        wiring_file = next(self.session_path.joinpath(self.device_collection, self.pname).glob('*wiring.json'), None)
        if wiring_file is not None:
            out_files.append(wiring_file)

        ephys_files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname))
        if len(ephys_files) > 0:
            bin_file = ephys_files[0].get('ap', None)

        # This is the case where no ap.bin / ap.cbin file exists
        if len(ephys_files) == 0 or not bin_file:
            ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="ch")
            ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="meta")
            for ef in ephys_files:
                for typ in ["ap", "lf"]:
                    bin_file = ef.get(typ)
                    if bin_file:
                        out_files.append(bin_file)

            return out_files

        # If the ap.bin / ap.cbin file does exist, instantiate the NP2Converter
        np_conv = neuropixel.NP2Converter(bin_file, compress=True)
        np_conv_status = np_conv.process()
        np_conv_files = np_conv.get_processed_files_NP21()
        np_conv.sr.close()

        # Status = 1 - completed successfully
        if np_conv_status == 1:
            out_files += np_conv_files
            return out_files
        # Status = 0 - already processed
        elif np_conv_status == 0:
            ephys_files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname))
            ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="ch")
            ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname), ext="meta")
            for ef in ephys_files:
                for typ in ["ap", "lf"]:
                    bin_file = ef.get(typ)
                    if bin_file and bin_file.suffix != '.bin':
                        out_files.append(bin_file)

            return out_files

        else:
            return


class EphysCompressNP24(base_tasks.EphysTask):
    """
    Compresses NP2.4 data by splitting it into N binary files, corresponding to N shanks.
    :param pname: a probe name string
    :param device_collection: the collection containing the probes (usually 'raw_ephys_data')
    :param nshanks: number of shanks used (usually 4 but it may be less depending on the electrode map), optional
    """

    priority = 90
    cpu = 2
    io_charge = 100  # this job reads raw ap files
    job_size = 'large'

    def __init__(self, session_path, *args, pname=None, device_collection='raw_ephys_data', nshanks=None, **kwargs):
        assert pname, "pname is a required argument"
        if nshanks is None:
            meta_file = next(session_path.joinpath(device_collection, pname).glob('*ap.meta'))
            nshanks = spikeglx._get_nshanks_from_meta(spikeglx.read_meta_data(meta_file))
        assert nshanks > 1
        super(EphysCompressNP24, self).__init__(
            session_path, *args, pname=pname, device_collection=device_collection, nshanks=nshanks, **kwargs)

    @property
    def signature(self):
        signature = {
            'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*ap.*bin', f'{self.device_collection}/{self.pname}', True),
                            ('*wiring.json', f'{self.device_collection}/{self.pname}', False)],
            'output_files': [('*ap.meta', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*ap.cbin', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*ap.ch', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*lf.meta', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*lf.cbin', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*lf.ch', f'{self.device_collection}/{self.pname}{pext}', True) for pext in self.pextra] +
                            [('*wiring.json', f'{self.device_collection}/{self.pname}{pext}', False) for pext in self.pextra]
        }
        return signature

    def _run(self, delete_original=True):
        # Do we need the ability to register the files once they have already been processed
        # and the original file deleted?
        files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname))
        assert len(files) == 1
        bin_file = files[0].get('ap', None)

        np_conv = neuropixel.NP2Converter(bin_file, post_check=True, compress=True, delete_original=delete_original)
        np_conv_status = np_conv.process()
        out_files = np_conv.get_processed_files_NP24()
        np_conv.sr.close()

        if np_conv_status == 1:
            wiring_file = next(self.session_path.joinpath(self.device_collection, self.pname).glob('*wiring.json'), None)
            if wiring_file is not None:
                # copy the wiring file to each sub-probe directory and add it to the output files
                for pext in self.pextra:
                    new_wiring_file = self.session_path.joinpath(self.device_collection, f'{self.pname}{pext}', wiring_file.name)
                    shutil.copyfile(wiring_file, new_wiring_file)
                    out_files.append(new_wiring_file)
            return out_files
        elif np_conv_status == 0:
            for pext in self.pextra:
                ephys_files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, f'{self.pname}{pext}'))
                ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection,
                                                                                    f'{self.pname}{pext}'), ext="ch")
                ephys_files += spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection,
                                                                                    f'{self.pname}{pext}'), ext="meta")
                for ef in ephys_files:
                    for typ in ["ap", "lf"]:
                        bin_file = ef.get(typ)
                        if bin_file and bin_file.suffix != '.bin':
                            out_files.append(bin_file)

                wiring_file = next(self.session_path.joinpath(self.device_collection,
                                                              f'{self.pname}{pext}').glob('*wiring.json'), None)
                if wiring_file is None:
                    # See if we have the original wiring file
                    orig_wiring_file = next(self.session_path.joinpath(self.device_collection,
                                                                       self.pname).glob('*wiring.json'), None)
                    if orig_wiring_file is not None:
                        # copy the wiring file to the sub-probe directory and add it to the output files
                        new_wiring_file = self.session_path.joinpath(self.device_collection, f'{self.pname}{pext}',
                                                                     orig_wiring_file.name)
                        shutil.copyfile(orig_wiring_file, new_wiring_file)
                        out_files.append(new_wiring_file)
                else:
                    out_files.append(wiring_file)

            return out_files
        else:
            return

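# Illustrative usage of the NP2.4 splitter task (a sketch; the session path and probe name are
# hypothetical, and registration normally happens through the pipeline rather than manually):
#
#   task = EphysCompressNP24(Path('/data/subject/2025-01-01/001'), pname='probe00')
#   task.run()  # writes probe00a..probe00d sub-collections, one compressed binary per shank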

class EphysSyncPulses(SyncPulses):

    priority = 90
    cpu = 2
    io_charge = 30  # this job reads raw ap files
    job_size = 'small'

    @property
    def signature(self):
        signature = {
            'input_files': [('*nidq.cbin', self.sync_collection, False),
                            ('*nidq.ch', self.sync_collection, False),
                            ('*nidq.meta', self.sync_collection, False),
                            ('*nidq.wiring.json', self.sync_collection, True)],
            'output_files': [('_spikeglx_sync.times.npy', self.sync_collection, True),
                             ('_spikeglx_sync.polarities.npy', self.sync_collection, True),
                             ('_spikeglx_sync.channels.npy', self.sync_collection, True)]
        }

        return signature


class EphysPulses(base_tasks.EphysTask):
    """
    Extracts pulses from raw electrophysiology data into numpy arrays and performs the probe
    synchronisation with the nidq (3B) or the main probe (3A).
    The job first extracts the sync pulses from all probes, then performs the synchronisation
    with the nidq.

    :param pname: a list of probe names or a single probe name string
    :param device_collection: the collection containing the probes (usually 'raw_ephys_data')
    :param sync_collection: the collection containing the synchronisation device - nidq (usually 'raw_ephys_data')
    """
    priority = 90
    cpu = 2
    io_charge = 30  # this job reads raw ap files
    job_size = 'small'

    def __init__(self, *args, **kwargs):
        super(EphysPulses, self).__init__(*args, **kwargs)
        assert self.device_collection, "device_collection is a required argument"
        assert self.sync_collection, "sync_collection is a required argument"
        self.pname = [self.pname] if isinstance(self.pname, str) else self.pname
        assert isinstance(self.pname, list), 'pname task argument should be a list or a string'

    @property
    def signature(self):
        signature = {
            'input_files':
                [('*ap.meta', f'{self.device_collection}/{pname}', True) for pname in self.pname] +
                [('*ap.cbin', f'{self.device_collection}/{pname}', True) for pname in self.pname] +
                [('*ap.ch', f'{self.device_collection}/{pname}', True) for pname in self.pname] +
                [('*ap.wiring.json', f'{self.device_collection}/{pname}', False) for pname in self.pname] +
                [('_spikeglx_sync.times.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] +
                [('_spikeglx_sync.polarities.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] +
                [('_spikeglx_sync.channels.*npy', f'{self.device_collection}/{pname}', False) for pname in self.pname] +
                [('_spikeglx_sync.times.*npy', self.sync_collection, True),
                 ('_spikeglx_sync.polarities.*npy', self.sync_collection, True),
                 ('_spikeglx_sync.channels.*npy', self.sync_collection, True),
                 ('*ap.meta', self.sync_collection, True)
                 ],
            'output_files': [(f'_spikeglx_sync.times.{pname}.npy', f'{self.device_collection}/{pname}', True)
                             for pname in self.pname] +
                            [(f'_spikeglx_sync.polarities.{pname}.npy', f'{self.device_collection}/{pname}', True)
                             for pname in self.pname] +
                            [(f'_spikeglx_sync.channels.{pname}.npy', f'{self.device_collection}/{pname}', True)
                             for pname in self.pname] +
                            [('*sync.npy', f'{self.device_collection}/{pname}', True) for pname in self.pname] +
                            [('*timestamps.npy', f'{self.device_collection}/{pname}', True) for pname in self.pname]
        }

        return signature

    def _run(self, overwrite=False):
        out_files = []
        for probe in self.pname:
            files = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, probe))
            assert len(files) == 1  # will error if the session is split
            bin_file = files[0].get('ap', None)
            if not bin_file:
                return []
            _, out = extract_sync(self.session_path, ephys_files=files, overwrite=overwrite)
            out_files += out

        status, sync_files = sync_probes.sync(self.session_path, probe_names=self.pname)

        return out_files + sync_files

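# A minimal sketch of the two steps performed above, outside of the task framework (the session
# path and probe name are hypothetical):
#
#   files = spikeglx.glob_ephys_files(session_path.joinpath('raw_ephys_data', 'probe00'))
#   _, sync_out = extract_sync(session_path, ephys_files=files)  # pulses -> _spikeglx_sync.* arrays
#   status, sync_files = sync_probes.sync(session_path, probe_names=['probe00'])  # probe clock -> nidq clock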

class RawEphysQC(base_tasks.EphysTask):

    cpu = 2
    io_charge = 30  # this job reads raw ap files
    priority = 10  # a lot of jobs depend on this one
    job_size = 'small'

    @property
    def signature(self):
        signature = {
            'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*lf.meta', f'{self.device_collection}/{self.pname}', True),
                            ('*lf.ch', f'{self.device_collection}/{self.pname}', False),
                            ('*lf.*bin', f'{self.device_collection}/{self.pname}', False)],
            'output_files': [('_iblqc_ephysChannels.apRMS.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysChannels.rawSpikeRates.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysChannels.labels.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysSpectralDensityLF.power.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysSpectralDensityAP.freqs.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysSpectralDensityAP.power.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysTimeRmsLF.rms.npy', f'{self.device_collection}/{self.pname}', True),
                             ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'{self.device_collection}/{self.pname}', True)]
        }
        return signature

    # TODO make sure this works with NP2 probes (at the moment not sure it will due to raiseError mapping)
    def _run(self, overwrite=False):
        eid = self.one.path2eid(self.session_path)
        probe = self.one.alyx.rest('insertions', 'list', session=eid, name=self.pname)

        # We expect there to be only one probe
        if len(probe) != 1:
            _logger.warning(f"{self.pname} for {eid} not found")  # Should we create it?
            self.status = -1
            return

        pid = probe[0]['id']
        qc_files = []
        _logger.info(f"\nRunning QC for probe insertion {self.pname}")
        try:
            eqc = ephysqc.EphysQC(pid, session_path=self.session_path, one=self.one)
            qc_files.extend(eqc.run(update=True, overwrite=overwrite))
            _logger.info("Creating LFP QC plots")
            plot_task = LfpPlots(pid, session_path=self.session_path, one=self.one)
            _ = plot_task.run()
            self.plot_tasks.append(plot_task)
            plot_task = BadChannelsAp(pid, session_path=self.session_path, one=self.one)
            _ = plot_task.run()
            self.plot_tasks.append(plot_task)

        except AssertionError:
            _logger.error(traceback.format_exc())
            self.status = -1

        return qc_files

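# Illustrative standalone QC run mirroring the task body above (a sketch; the insertion id and
# session path are hypothetical and a ONE instance is assumed to be available as `one`):
#
#   eqc = ephysqc.EphysQC('<insertion-uuid>', session_path=session_path, one=one)
#   qc_files = eqc.run(update=True, overwrite=False)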

class CellQCMixin:
    """
    This mixin class is used to compute the cell QC metrics and update the json field of the probe insertion.
    The compute_cell_qc method is static and can be used independently.
    """
    @staticmethod
    def compute_cell_qc(folder_alf_probe):
        """
        Computes the cell QC given an extracted probe alf path.
        :param folder_alf_probe: path to the probe alf folder containing the spikes and clusters objects
        :return: list of the QC files written (cluster metrics and passing-spikes tables), the units dataframe
         and the drift
        """
        # compute the straight qc
        _logger.info(f"Computing cluster qc for {folder_alf_probe}")
        spikes = alfio.load_object(folder_alf_probe, 'spikes')
        clusters = alfio.load_object(folder_alf_probe, 'clusters')
        df_units, drift = ephysqc.spike_sorting_metrics(
            spikes.times, spikes.clusters, spikes.amps, spikes.depths,
            cluster_ids=np.arange(clusters.channels.size))
        # if the ks2 labels file exists, load it and add the column
        file_labels = folder_alf_probe.joinpath('cluster_KSLabel.tsv')
        if file_labels.exists():
            ks2_labels = pd.read_csv(file_labels, sep='\t')
            ks2_labels.rename(columns={'KSLabel': 'ks2_label'}, inplace=True)
            df_units = pd.concat(
                [df_units, ks2_labels['ks2_label'].reindex(df_units.index)], axis=1)
        # save as parquet file
        df_units.to_parquet(file_metrics := folder_alf_probe.joinpath("clusters.metrics.pqt"))

        assert np.all((df_units['bitwise_fail'] == 0) == (df_units['label'] == 1))  # useless but sanity check for OW

        cok = df_units['bitwise_fail'] == 0
        sok = cok[spikes['clusters']].values
        spikes['templates'] = spikes['templates'].astype(np.uint16)
        spikes['clusters'] = spikes['clusters'].astype(np.uint16)
        spikes['depths'] = spikes['depths'].astype(np.float32)
        spikes['amps'] = spikes['amps'].astype(np.float32)
        file_passing = folder_alf_probe.joinpath('passingSpikes.table.pqt')
        df_spikes = pd.DataFrame(spikes)
        df_spikes = df_spikes.iloc[sok, :].reset_index(drop=True)
        df_spikes.to_parquet(file_passing)

        return [file_metrics, file_passing], df_units, drift

    def _label_probe_qc(self, folder_probe, df_units, drift):
        """
        Labels the json field of the corresponding Alyx probe insertion.
        :param folder_probe: folder containing the spike sorting output for the probe
        :param df_units: dataframe of unit metrics, as returned by compute_cell_qc
        :param drift: drift output, as returned by compute_cell_qc
        :return:
        """
        eid = self.one.path2eid(self.session_path, query_type='remote')
        pdict = self.one.alyx.rest('insertions', 'list', session=eid, name=self.pname, no_cache=True)
        if len(pdict) != 1:
            _logger.warning(f'No probe found for probe name: {self.pname}')
            return
        isok = df_units['label'] == 1
        qcdict = {'n_units': int(df_units.shape[0]),
                  'n_units_qc_pass': int(np.sum(isok)),
                  'firing_rate_max': np.max(df_units['firing_rate'][isok]),
                  'firing_rate_median': np.median(df_units['firing_rate'][isok]),
                  'amplitude_max_uV': np.max(df_units['amp_max'][isok]) * 1e6,
                  'amplitude_median_uV': np.median(df_units['amp_median'][isok]) * 1e6,
                  'drift_rms_um': rms(drift['drift_um']),
                  }
        file_wm = folder_probe.joinpath('_kilosort_whitening.matrix.npy')
        if file_wm.exists():
            wm = np.load(file_wm)
            qcdict['whitening_matrix_conditioning'] = np.linalg.cond(wm)
        # groom the qc dict (this function will eventually go directly into the json field update)
        for k in qcdict:
            if isinstance(qcdict[k], np.int64):
                qcdict[k] = int(qcdict[k])
            elif isinstance(qcdict[k], float):
                qcdict[k] = np.round(qcdict[k], 2)
        self.one.alyx.json_field_update("insertions", pdict[0]["id"], "json", qcdict)

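# Illustrative independent use of the static QC method above (a sketch; the path is hypothetical):
#
#   files_qc, df_units, drift = CellQCMixin.compute_cell_qc(
#       Path('/data/subject/2025-01-01/001/alf/probe00/iblsorter'))
#   # files_qc -> [clusters.metrics.pqt, passingSpikes.table.pqt] written in that folder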

class SpikeSorting(base_tasks.EphysTask, CellQCMixin):
    """
    Spike sorting task running the ibl-sorter package (originally the pykilosort 2.5 pipeline).
    """
    gpu = 1
    io_charge = 100  # this job reads raw ap files
    priority = 60
    job_size = 'large'
    force = True
    env = 'iblsorter'
    _sortername = 'iblsorter'
    SHELL_SCRIPT = Path.home().joinpath(
        f"Documents/PYTHON/iblscripts/deploy/serverpc/{_sortername}/sort_recording.sh"
    )
    SPIKE_SORTER_NAME = 'iblsorter'
    SORTER_REPOSITORY = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter')

    @property
    def signature(self):
        signature = {
            'input_files': [
                ('*ap.meta', f'{self.device_collection}/{self.pname}', True),
                ('*ap.*bin', f'{self.device_collection}/{self.pname}', True),
                ('*ap.ch', f'{self.device_collection}/{self.pname}', False),
                ('*sync.npy', f'{self.device_collection}/{self.pname}', True)
            ],
            'output_files': [
                # ./raw_ephys_data/{self.pname}/
                ('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}/', True),
                ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'{self.device_collection}/{self.pname}/', True),
                ('_iblqc_ephysSaturation.samples.npy', f'{self.device_collection}/{self.pname}/', True),
                # ./spike_sorters/iblsorter/{self.pname}
                ('_kilosort_raw.output.tar', f'spike_sorters/{self._sortername}/{self.pname}/', True),
                # ./alf/{self.pname}/iblsorter
                (f'_ibl_log.info_{self.SPIKE_SORTER_NAME}.log', f'alf/{self.pname}/{self._sortername}', True),
                ('_kilosort_whitening.matrix.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('_phy_spikes_subset.channels.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('_phy_spikes_subset.spikes.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('_phy_spikes_subset.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('channels.labels.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('channels.localCoordinates.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('channels.rawInd.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.amps.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.channels.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.depths.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.metrics.pqt', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.peakToTrough.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.uuids.csv', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('clusters.waveformsChannels.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('drift.times.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('drift.um.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('drift_depths.um.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('passingSpikes.table.pqt', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.amps.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.clusters.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.depths.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.samples.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.templates.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('spikes.times.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('templates.amps.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('templates.waveforms.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('templates.waveformsChannels.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('waveforms.channels.npz', f'alf/{self.pname}/{self._sortername}/', True),
                ('waveforms.table.pqt', f'alf/{self.pname}/{self._sortername}/', True),
                ('waveforms.templates.npy', f'alf/{self.pname}/{self._sortername}/', True),
                ('waveforms.traces.npy', f'alf/{self.pname}/{self._sortername}/', True),
            ],
        }
        return signature


    @property
    def scratch_folder_run(self):
        """
        Constructs the path to a temporary folder for the spike sorting output and scratch files.
        This is usually on a high-performance drive, and around 2.5 times the size of the
        uncompressed raw recording should be budgeted.
        For a scratch drive at /mnt/h0 we would have the following temp dir:
        /mnt/h0/iblsorter_1.8.0_CSHL071_2020-10-04_001_probe01/
        """
        # get the scratch drive from the shell script
        if self.scratch_folder is None:
            with open(self.SHELL_SCRIPT) as fid:
                lines = fid.readlines()
            line = [line for line in lines if line.startswith("SCRATCH_DRIVE=")][0]
            m = re.search(r"\=(.*?)(\#|\n)", line)[0]
            scratch_drive = Path(m[1:-1].strip())
        else:
            scratch_drive = self.scratch_folder
        assert scratch_drive.exists(), f"Scratch drive {scratch_drive} not found"
        # get the version of the sorter
        self.version = self._fetch_iblsorter_version(self.SORTER_REPOSITORY)
        spikesorter_dir = f"{self.version}_{'_'.join(list(self.session_path.parts[-3:]))}_{self.pname}"
        return scratch_drive.joinpath(spikesorter_dir)


    @staticmethod
    def _sample2v(ap_file):
        md = spikeglx.read_meta_data(ap_file.with_suffix(".meta"))
        s2v = spikeglx._conversion_sample2v_from_meta(md)
        return s2v["ap"][0]

    @staticmethod
    def _fetch_iblsorter_version(repo_path):
        try:
            import iblsorter
            return f"iblsorter_{iblsorter.__version__}"
        except ImportError:
            _logger.info('IBL-sorter not in environment, trying to locate the repository')
        init_file = Path(repo_path).joinpath('iblsorter', '__init__.py')
        version = 'unknown'  # fallback so the return below never hits an undefined name
        try:
            with open(init_file) as fid:
                lines = fid.readlines()
                for line in lines:
                    if line.startswith("__version__ = "):
                        version = line.split('=')[-1].strip().replace('"', '').replace("'", '')
        except Exception:
            pass
        return f"iblsorter_{version}"


    @staticmethod
    def _fetch_iblsorter_run_version(log_file):
        """
        Parse the following line (2 formats depending on the version) from the log files to get the version:
        '\x1b[0m15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.2.1, output in gnagga^[[0m\n'
        '\x1b[0m15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.3.0^[[0m\n'
        """
        with open(log_file) as fid:
            for m in range(50):
                line = fid.readline()
                print(line.strip())
                version = re.search('version (.*)', line)
                if not line or version:
                    break
        if version is not None:
            version = re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', version.group(1))
            version = version.replace(',', ' ').split(' ')[0]  # breaks the string after the first space
        return version

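    # Illustrative parse of a matching log line (a sketch based on the formats in the docstring;
    # the log file name is hypothetical):
    #
    #   SpikeSorting._fetch_iblsorter_run_version('_ibl_log.info_iblsorter.log')
    #   # e.g. returns 'ibl_1.3.0' for a log starting with
    #   # '15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.3.0'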

    def _run_iblsort(self, ap_file):
        """
        Runs the iblsorter spike sorting for one probe dataset.
        The raw spike sorting output is in the session_path/spike_sorters/{self.SPIKE_SORTER_NAME}/probeXX folder
        (discontinued support for old spike sortings in the probe folder <1.5.5).
        :return: path of the folder containing the spike sorting output
        """
        iblutil.util.setup_logger('iblsorter', level='INFO')
        sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, self.pname)
        self.FORCE_RERUN = False
        if not self.FORCE_RERUN:
            log_file = sorter_dir.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log")
            if log_file.exists():
                run_version = self._fetch_iblsorter_run_version(log_file)
                if packaging.version.parse(run_version) >= packaging.version.parse('1.7.0'):
                    _logger.info(f"Already ran: {log_file}"
                                 f" found in {sorter_dir}, skipping.")
                    return sorter_dir
                else:
                    self.FORCE_RERUN = True
        self.scratch_folder_run.mkdir(parents=True, exist_ok=True)
        check_nvidia_driver()
        try:
            # if iblsorter is in the environment, use the installed version within the task
            import iblsorter.ibl  # noqa
            iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=self.scratch_folder_run, delete=False)
        except ImportError:
            command2run = f"{self.SHELL_SCRIPT} {ap_file} {self.scratch_folder_run}"
            _logger.info(command2run)
            process = subprocess.Popen(
                command2run,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                executable="/bin/bash",
            )
            info, error = process.communicate()
            info_str = info.decode("utf-8").strip()
            _logger.info(info_str)
            if process.returncode != 0:
                error_str = error.decode("utf-8").strip()
                # try and get the kilosort log if any
                for log_file in self.scratch_folder_run.rglob('*_kilosort.log'):
                    with open(log_file) as fid:
                        log = fid.read()
                        _logger.error(log)
                    break
                raise RuntimeError(f"{self.SPIKE_SORTER_NAME} {info_str}, {error_str}")
        shutil.copytree(self.scratch_folder_run.joinpath('output'), sorter_dir, dirs_exist_ok=True)
        return sorter_dir


    def _run(self):
        """
        Multiple steps. For each probe:
        - runs the sorter (skips if it already ran)
        - synchronises the spike sorting
        - outputs the probe description files
        - computes the waveforms
        :return: list of files to be registered on the database
        """
        efiles = spikeglx.glob_ephys_files(self.session_path.joinpath(self.device_collection, self.pname))
        ap_files = [(ef.get("ap"), ef.get("label")) for ef in efiles if "ap" in ef.keys()]
        assert len(ap_files) != 0, f"No ap file found for probe {self.session_path.joinpath(self.device_collection, self.pname)}"
        assert len(ap_files) == 1, f"Several bin files found for the same probe {ap_files}"
        ap_file, label = ap_files[0]
        out_files = []
        sorter_dir = self._run_iblsort(ap_file)  # runs the sorter, skips if it already ran
        # convert the data to ALF in the ./alf/probeXX/SPIKE_SORTER_NAME folder
        probe_out_path = self.session_path.joinpath("alf", label, self.SPIKE_SORTER_NAME)
        shutil.rmtree(probe_out_path, ignore_errors=True)
        probe_out_path.mkdir(parents=True, exist_ok=True)
        ibllib.ephys.spikes.ks2_to_alf(
            sorter_dir,
            bin_path=ap_file.parent,
            out_path=probe_out_path,
            bin_file=ap_file,
            ampfactor=self._sample2v(ap_file),
        )
        logfile = sorter_dir.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log")
        if logfile.exists():
            shutil.copyfile(logfile, probe_out_path.joinpath(f"_ibl_log.info_{self.SPIKE_SORTER_NAME}.log"))
        # recover the QC files from the spike sorting output and copy them
        for file_qc in sorter_dir.glob('_iblqc_*.npy'):
            shutil.move(file_qc, file_qc_out := ap_file.parent.joinpath(file_qc.name))
            out_files.append(file_qc_out)
        # Sync spike sorting with the main behaviour clock: the nidq for 3B+ and the main probe for 3A
        out, _ = ibllib.ephys.spikes.sync_spike_sorting(ap_file=ap_file, out_path=probe_out_path)
        out_files.extend(out)
        # Now compute the unit metrics
        files_qc, df_units, drift = self.compute_cell_qc(probe_out_path)
        out_files.extend(files_qc)
        # convert the sorter output into a tar file and register it as well
        # Create the directory in case the spike sorting is in old raw_ephys_data folders;
        # for new sessions it should already exist
        tar_dir = self.session_path.joinpath('spike_sorters', self.SPIKE_SORTER_NAME, label)
        tar_dir.mkdir(parents=True, exist_ok=True)
        out = ibllib.ephys.spikes.ks2_to_tar(sorter_dir, tar_dir, force=self.FORCE_RERUN)
        out_files.extend(out)
        # run waveform extraction
        _logger.info(f"Cleaning up temporary folder {self.scratch_folder_run}")
        shutil.rmtree(self.scratch_folder_run, ignore_errors=True)
        _logger.info("Running waveform extraction")
        spikes = alfio.load_object(probe_out_path, 'spikes', attribute=['samples', 'clusters'])
        clusters = alfio.load_object(probe_out_path, 'clusters', attribute=['channels'])
        channels = alfio.load_object(probe_out_path, 'channels')
        _output_waveform_files = extract_wfs_cbin(
            bin_file=ap_file,
            output_dir=probe_out_path,
            spike_samples=spikes['samples'],
            spike_clusters=spikes['clusters'],
            spike_channels=clusters['channels'][spikes['clusters']],
            channel_labels=channels['labels'],
            max_wf=256,
            trough_offset=42,
            spike_length_samples=128,
            chunksize_samples=int(30_000),
            n_jobs=None,
            wfs_dtype=np.float16,
            preprocess_steps=["phase_shift", "bad_channel_interpolation", "butterworth", "car"],
            scratch_dir=self.scratch_folder_run,
        )
        out_files.extend(_output_waveform_files)
        _logger.info(f"Cleaning up temporary folder {self.scratch_folder_run}")
        shutil.rmtree(self.scratch_folder_run, ignore_errors=True)
        if self.one:
            eid = self.one.path2eid(self.session_path, query_type='remote')
            ins = self.one.alyx.rest('insertions', 'list', session=eid, name=label, query_type='remote')
            if len(ins) != 0:
                _logger.info("Populating probe insertion with qc")
                self._label_probe_qc(probe_out_path, df_units, drift)
                _logger.info("Creating SpikeSorting QC plots")
                plot_task = ApPlots(ins[0]['id'], session_path=self.session_path, one=self.one)
                _ = plot_task.run()
                self.plot_tasks.append(plot_task)

                plot_task = SpikeSortingPlots(ins[0]['id'], session_path=self.session_path, one=self.one)
                _ = plot_task.run(collection=str(probe_out_path.relative_to(self.session_path)))
                self.plot_tasks.append(plot_task)

                resolved = ins[0].get('json', {'temp': 0}).get('extended_qc', {'temp': 0}). \
                    get('alignment_resolved', False)
                if resolved:
                    chns = np.load(probe_out_path.joinpath('channels.localCoordinates.npy'))
                    out = get_aligned_channels(ins[0], chns, one=self.one, save_dir=probe_out_path)
                    out_files.extend(out)
        self.assert_expected_outputs()
        return sorted(list(set(out_files)))