Coverage for ibllib/plots/figures.py: 15%

520 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 15:25 +0000

1""" 

2Module that produces figures, usually for the extraction pipeline 

3""" 

4import logging 

5import time 

6from pathlib import Path 

7import traceback 

8from string import ascii_uppercase 

9 

10import numpy as np 

11import pandas as pd 

12import scipy.signal 

13import matplotlib.pyplot as plt 

14 

15from ibldsp import voltage 

16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot 

17from one.api import ONE 

18import one.alf.io as alfio 

19from one.alf.exceptions import ALFObjectNotFound 

20from ibllib.io.video import get_video_frame, url_from_eid 

21from ibllib.oneibl.data_handlers import ExpectedDataset 

22import spikeglx 

23import neuropixel 

24from brainbox.plot import driftmap 

25from brainbox.io.spikeglx import Streamer 

26from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \ 

27 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist 

28from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions 

29from brainbox.io.one import load_spike_sorting_fast 

30from brainbox.behavior import training 

31from iblutil.numerical import ismember 

32from ibllib.plots.misc import Density 

33 

34 

35logger = logging.getLogger(__name__) 

36 

37 

def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Apply uniform font sizes to the axis labels, tick labels and title of an axis.

    :param ax: matplotlib axis, modified in place
    :param labels: font size of the x and y axis labels (and of the colorbar label when `cmap` is True)
    :param ticklabels: font size of the tick labels (and of the colorbar tick labels when `cmap` is True)
    :param title: font size of the axis title
    :param cmap: if True, also resize the colorbar attached to the last image drawn on `ax`
    :return: None
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if not cmap:
        return
    colorbar_ax = ax.images[-1].colorbar.ax
    colorbar_ax.tick_params(labelsize=ticklabels)
    colorbar_ax.yaxis.get_label().set_fontsize(labels)

58 

59 

def remove_axis_outline(ax):
    """
    Hide both axes and the four spines so that an unused axis renders as blank space.

    :param ax: matplotlib axis, modified in place
    :return: None
    """
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

72 

73 

class BehaviourPlots(ReportSnapshot):
    """Behavioural plots: psychometric curve, chronometric curve and reaction time over trials."""

    @property
    def signature(self):
        # Outputs are listed in the same collection the snapshots are written to (see _snapshot_collection)
        snapshot_collection = self._snapshot_collection()
        signature = {
            'input_files': [
                ('*trials.table.pqt', self.trials_collection, True),
            ],
            'output_files': [
                ('psychometric_curve.png', snapshot_collection, True),
                ('chronometric_curve.png', snapshot_collection, True),
                ('reaction_time_with_trials.png', snapshot_collection, True)
            ]
        }
        return signature

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        Generate and upload behaviour plots.

        Parameters
        ----------
        eid : str, uuid.UUID
            An experiment UUID.
        session_path : pathlib.Path, str
            A session path. If None, it is resolved from `eid` with `one.eid2path`.
        one : one.api.One
            An instance of ONE for registration to Alyx.
        task_collection : str
            Keyword argument: the location of the trials data (default: 'alf'). Snapshots are written to
            snapshot/behaviour/<task_collection without its leading 'alf'>.
        kwargs
            Arguments for ReportSnapshot constructor.

        Raises
        ------
        ValueError
            If no session path is given and none could be resolved from the eid.
        """
        self.one = one
        self.eid = eid
        session_path = session_path or self.one.eid2path(self.eid)
        if session_path is None:
            raise ValueError(f'Could not determine a session path for eid {self.eid}')
        # Accept str paths too: joinpath is called on this attribute below and in _run
        self.session_path = Path(session_path)
        # The documented keyword is 'task_collection'; it is popped so it is not forwarded to ReportSnapshot
        self.trials_collection = kwargs.pop('task_collection', 'alf')
        super().__init__(self.session_path, self.eid, one=self.one, **kwargs)
        self.output_directory = self.session_path.joinpath(self._snapshot_collection())
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _snapshot_collection(self):
        """Return the snapshot collection, mirroring the trials collection without its 'alf' prefix."""
        sub_collection = self.trials_collection.removeprefix('alf').strip('/')
        return '/'.join(part for part in ('snapshot', 'behaviour', sub_collection) if part)

    def _run(self):
        """
        Plot and save the behaviour snapshots.

        :return: list of the saved png paths, in the order psychometric, chronometric, reaction time over trials
        """
        trials = alfio.load_object(self.session_path.joinpath(self.trials_collection), 'trials')
        if self.one:
            title = self.one.path2ref(self.session_path, as_dict=False)
        else:
            title = '_'.join(self.session_path.parts[-3:])

        plots = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )
        output_files = []
        for plot_function, filename in plots:
            fig, ax = plot_function(trials, title=title, figsize=(8, 6))
            set_axis_label_size(ax)
            save_path = Path(self.output_directory).joinpath(filename)
            output_files.append(save_path)
            fig.savefig(save_path)
            plt.close(fig)

        return output_files

150 

151 

# TODO put into histology and alignment pipeline
class HistologySlices(ReportSnapshotProbe):
    """Plots coronal and sagittal slice showing electrode locations."""

    def _run(self):
        """
        Save tilted slices through the electrode sites next to a brain region strip.

        :return: list holding the saved figure path, or an empty list when `self.hist_lookup`
         maps the current histology status to a non-positive value
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        if not self.hist_lookup[self.histology_status] > 0:
            return []

        xyz = electrodes['mlapdv']
        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        # Upper panel: slice along atlas axis 1, sites projected on columns 0 and 2 (ml, dv);
        # coordinates scaled by 1e6 (presumably metres -> micrometres)
        ax_upper = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 1, ax=ax_upper)
        ax_upper.scatter(xyz[:, 0] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')
        ax_upper.set_title(f"{self.pid_label}")

        # Lower panel: slice along atlas axis 0, sites projected on columns 1 and 2 (ap, dv)
        ax_lower = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 0, ax=ax_lower)
        ax_lower.scatter(xyz[:, 1] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')

        # Right column spans both rows: brain regions of the electrode sites (atlas ids)
        ax_regions = fig.add_subplot(grid[:, 1])
        plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)

        save_path = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(save_path)
        plt.close(fig)
        return [save_path]

    def get_probe_signature(self):
        """Set `self.signature`: optional electrode site datasets in, one histology snapshot out."""
        alf_collection = f'alf/{self.pname}'
        self.signature = {
            'input_files': [(f'electrodeSites.{attribute}.npy', alf_collection, False)
                            for attribute in ('localCoordinates', 'brainLocationIds_ccf_2017', 'mlapdv')],
            'output_files': [('histology_slices.png', f'snapshot/{self.pname}', True)],
        }

194 

195 

class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS plots, each with an optional brain region strip
    """

    def _run(self):
        """
        Save the LFP power spectrum and LFP RMS snapshots of the probe.

        :return: list of the two saved png paths (spectrum first, then RMS)
        """
        assert self.pid

        # Electrode sites are only fetched off-server; without them the brain region strip is left blank
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        raw_ephys_path = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')

        output_files = []

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(raw_ephys_path, 'ephysSpectralDensityLF', namespace='iblqc')
        image_lfp_spectrum_plot(lfp.power, lfp.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)}, ax=axs[0],
                                display=True, title=f"{self.pid_label}")
        output_files.append(self._save_with_brain_regions(fig, axs, electrodes, "lfp_spectrum.png"))

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp_rms = alfio.load_object(raw_ephys_path, 'ephysTimeRmsLF', namespace='iblqc')
        image_rms_plot(lfp_rms.rms, lfp_rms.timestamps, median_subtract=False, band='LFP', clim=[-35, -45], ax=axs[0],
                       cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        output_files.append(self._save_with_brain_regions(fig, axs, electrodes, "lfp_rms.png"))

        return output_files

    def _save_with_brain_regions(self, fig, axs, electrodes, filename):
        """
        Finish a two-panel snapshot: resize labels, draw or blank the brain region strip, save and close.

        :param fig: figure to save
        :param axs: pair of axes (data panel, brain region strip)
        :param electrodes: electrode sites with an 'atlas_id' entry, or None when they were not loaded
        :param filename: png file name written in `self.output_directory`
        :return: path of the saved figure
        """
        set_axis_label_size(axs[0], cmap=True)
        # Check electrodes first: on the server they are never loaded and histology_status is not computed
        if electrodes is not None and self.histology_status:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath(filename)
        fig.savefig(save_path)
        plt.close(fig)
        return save_path

    def get_probe_signature(self):
        """Set `self.signature`: LF QC datasets (required) and electrode sites (optional) in, two snapshots out."""
        input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
                            ('lfp_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

262 

263 

class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS plots, with an optional brain region strip
    """

    def _run(self):
        """
        Save the AP band RMS snapshot of the probe.

        :return: list holding the saved png path
        """
        assert self.pid

        # Electrode sites are only fetched off-server; without them the brain region strip is left blank
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP', namespace='iblqc')
        image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
                       fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        # Check electrodes first: on the server they are never loaded and histology_status is not computed
        if electrodes is not None and self.histology_status:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath("ap_rms.png")
        fig.savefig(save_path)
        plt.close(fig)

        return [save_path]

    def get_probe_signature(self):
        """Set `self.signature`: AP RMS QC datasets (required) and electrode sites (optional) in, one snapshot out."""
        input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

307 

308 

class SpikeSorting(ReportSnapshotProbe):
    """
    Plots spike sorting driftmaps, one raster per spike sorting run of the probe
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _run(self, collection=None):
        """
        Plot a driftmap for each spike sorting run of the probe.

        On the server a single run is plotted from the local ALF files of `collection`. Elsewhere, every
        run listed on Alyx under alf/<pname>* is plotted, unless all expected snapshots already exist.

        :param collection: ALF collection holding the spike sorting; required when location is 'server'
        :return: list of the saved png paths
        """
        if self.location == 'server':
            assert collection
            alf_path = self.session_path.joinpath(collection)
            spikes = alfio.load_object(alf_path, 'spikes')
            clusters = alfio.load_object(alf_path, 'clusters')
            channels = alfio.load_object(alf_path, 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]
            return [self._plot_driftmap(spikes, clusters, channels, collection)]

        self.histology_status = self.get_histology_status()
        all_here, existing_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
        if all_here and len(existing_files) == len(spike_sorting_runs):
            return existing_files
        logger.info(self.output_directory)
        # Every run is regenerated below, so start from an empty list: extending the list of already
        # existing snapshots would report those runs twice
        output_files = []
        for run in spike_sorting_runs:
            run_collection = Path(run).parent.as_posix()
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=run_collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', run_collection)

            output_files.append(self._plot_driftmap(spikes, clusters, channels, run_collection))

        return output_files

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Save the driftmap raster of one spike sorting run, next to an optional brain region strip.

        :param spikes: spikes object with 'times', 'depths' and 'clusters'
        :param clusters: clusters object with 'depths'
        :param channels: channels object with 'axial_um' (and 'atlas_id' to draw the brain region strip)
        :param collection: ALF collection of the run, used for the title and the output file name
        :return: path of the saved png
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # The run label is the sub-folder below alf/<pname>; the un-nested folder is labelled ks2matlab
        run_label = Path(collection).relative_to(f'alf/{self.pname}').as_posix()
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        if self.histology_status:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)
        return outfile

    def get_probe_signature(self):
        """Set `self.signature`: spike sorting datasets of any alf/<pname>* run in, raster snapshots out."""
        input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
                           ('spikes.amps.npy', f'alf/{self.pname}*', True),
                           ('spikes.depths.npy', f'alf/{self.pname}*', True),
                           ('clusters.depths.npy', f'alf/{self.pname}*', True),
                           ('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
                           ('channels.mlapdv.npy', f'alf/{self.pname}*', False),
                           ('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """
        Expand the input signature to every local folder under alf/ that holds a spikes.times.npy file.

        Falls back to the generic signature when no such folder exists. Sets `self.input_files` and
        `self.output_files` as ExpectedDataset instances.
        """
        files_spikes = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')
        folder_probes = [f.parent for f in files_spikes]
        # ALF collections are posix-style relative paths, whatever the OS
        full_input_files = [(sig[0], folder.relative_to(self.session_path).as_posix(), sig[2])
                            for sig in self.signature['input_files'] for folder in folder_probes]
        input_files = full_input_files or self.signature['input_files']
        self.input_files = [ExpectedDataset.input(*i) for i in input_files]
        self.output_files = [ExpectedDataset.output(*i) for i in self.signature['output_files']]

402 

403 

class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots raw electrophysiology AP band with dead, noisy and outside channels highlighted
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        """Set `self.signature`: AP metadata in, the four bad channel snapshots out (each listed once)."""
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """
        Load one second of AP data, detect bad channels, then save the bad channel snapshots.

        Off-server the data is streamed starting 30 min into the recording. On the server the local
        ap binary is read at the same offset, or 500 s before the end for shorter recordings.

        :return: list of saved png paths; the existing snapshots when all four are already present;
         an empty list on the server when no ap binary file is found
        """
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30  # start of the displayed snippet (s)
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        # ephys_bad_channels(destripe=True) writes 4 figures: summary, highpass, destripe and difference
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            nsecs = 1
            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If T0 is greater than recording length, take 500 sec before end; clamp at 0 so that
            # recordings shorter than 500 s do not produce a negative (wrapping) sample index
            if sr.rl < T0:
                T0 = max(0, int(sr.rl - 500))
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files

480 

481 

def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Display voltage traces with bad channels highlighted, plus a summary figure of the channel features.

    Channels are highlighted from `channel_labels`: 1 -> dead (blue), 2 -> noisy (red), 3 -> outside
    (goldenrod with matplotlib, rgb (164, 142, 35) with viewephys).

    :param raw: voltage array of shape (nc, ns), channels by samples; multiplied by 1e6 and labelled uV
     for the PSD, so presumably in volts
    :param fs: sampling frequency (Hz); fs >= 2600 selects AP band display settings, otherwise LFP settings
    :param channel_labels: per-channel integer labels (1 dead, 2 noisy, 3 outside)
    :param channel_features: dict-like with per-channel 'rms_raw', 'psd_hf', 'xcor_hf' and 'xcor_lf'
    :param h: trace header passed to voltage.destripe
    :param channels: optional dict with 'atlas_id' (and optionally 'title') used to draw the brain region strip
    :param title: suptitle of the features figure and prefix of the saved file names
    :param save_dir: if given, directory in which all figures are saved as png
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: optional list to which the trace displays are appended (a new list is created otherwise)
    :param br: brain regions object passed to the brain region plots
    :param pid_info: label prepended to the trace display titles (matplotlib backend)
    :param plot_backend: 'matplotlib' (default); any other value uses viewspikes.gui.viewephys
    :return: (features figure, trace displays, saved paths) if `save_dir` is given,
     otherwise (features figure, trace displays)
    """
    nc, ns = raw.shape
    rl = ns / fs  # snippet duration (s)

    def gain2level(gain):
        # symmetric colour limits; gain presumably in dB given the 10 ** (gain / 20) conversion
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        psd_hf_marker_max = 0.0999  # just below the top of ylim_psd_hf so off-scale markers stay visible
        eqc_xrange = [450, 500]  # displayed window (ms)
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
        eqc_levels = gain2level(eqc_gain)
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        psd_hf_marker_max = 0.999  # just below the top of ylim_psd_hf so off-scale markers stay visible
        eqc_xrange = [450, 950]  # displayed window (ms)
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
        eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            # one vertical line of markers per bad channel across the displayed time (ms)
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa
        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], psd_hf_marker_max), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    # call legend(): assigning to the attribute would silently replace the method and draw nothing
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is not None:
        output_files = [Path(save_dir).joinpath(f"{title}.png")]
        fig.savefig(output_files[0])
        for eqc in eqcs:
            if plot_backend == 'matplotlib':
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
                eqc.figure.savefig(str(output_files[-1]))
            else:
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
                eqc.grab().save(str(output_files[-1]))
        return fig, eqcs, output_files
    else:
        return fig, eqcs

614 

615 

def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a raw ephys snippet and display it as a density image in one panel of a multi-panel figure.

    :param raw: raw ephys data, Nc x Ns (channels x samples)
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start
    :param i_plt: index of the subplot to draw in (start from 0, has to be < n_plt)
    :param n_plt: total number of subplots on the figure
    :param fig: figure handle
    :param axs: axis handles (1D sequence of axes); a new figure is created if fig or axs is None
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; other channel counts are not destriped
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if i_plt is not a valid index into axs
    '''

    # Import
    from ibldsp import voltage
    from ibllib.plots import Density

    # Init fig
    if fig is None or axs is None:
        # width_ratios must be a sequence with one entry per column; squeeze=False keeps axs indexable
        # when n_plt == 1 (a single Axes would otherwise be returned)
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), squeeze=False,
                                gridspec_kw={'width_ratios': [4] * n_plt})
        axs = axs[0]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt + 1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    nc, _ = raw.shape
    ax = axs[i_plt]
    if nc == N_CHAN:
        destriped = voltage.destripe(raw, fs=fs)
        # Keep the first DISPLAY_TIME seconds, then drop the first SAMPLE_SKIP samples (artifact at beginning)
        Xs = destriped[:, :int(DISPLAY_TIME * fs)][:, SAMPLE_SKIP:]
        Tplot = Xs.shape[1] / fs

        # PLOT RAW DATA
        Density(-Xs, fs=fs, taxis=1, ax=ax, vmin=MIN_X, vmax=MAX_X)
        ax.set_ylabel('')
        ax.set_xlim((0, Tplot * 1e3))
        ax.set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect and remove bad channels prior to spike detection
            labels, _ = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data: one row of markers per bad channel across the displayed time (ms)
            chans, times = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            ax.plot(times.flatten(), chans.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        ax.title.set_text(title_plt)

    else:
        ax.title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        ax.set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs

692 

693 

694def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data', 

695 cameras=('left', 'right', 'body'), trials_collection='alf'): 

696 """ 

697 Creates DLC QC plot. 

698 Data is searched first locally, then on Alyx. Panels that lack required data are skipped. 

699 

700 Required data to create all panels 

701 'raw_video_data/_iblrig_bodyCamera.raw.mp4', 

702 'raw_video_data/_iblrig_leftCamera.raw.mp4', 

703 'raw_video_data/_iblrig_rightCamera.raw.mp4', 

704 'alf/_ibl_bodyCamera.dlc.pqt', 

705 'alf/_ibl_leftCamera.dlc.pqt', 

706 'alf/_ibl_rightCamera.dlc.pqt', 

707 'alf/_ibl_bodyCamera.times.npy', 

708 'alf/_ibl_leftCamera.times.npy', 

709 'alf/_ibl_rightCamera.times.npy', 

710 'alf/_ibl_leftCamera.features.pqt', 

711 'alf/_ibl_rightCamera.features.pqt', 

712 'alf/rightROIMotionEnergy.position.npy', 

713 'alf/leftROIMotionEnergy.position.npy', 

714 'alf/bodyROIMotionEnergy.position.npy', 

715 'alf/_ibl_trials.choice.npy', 

716 'alf/_ibl_trials.feedbackType.npy', 

717 'alf/_ibl_trials.feedback_times.npy', 

718 'alf/_ibl_trials.stimOn_times.npy', 

719 'alf/_ibl_wheel.position.npy', 

720 'alf/_ibl_wheel.timestamps.npy', 

721 'alf/licks.times.npy', 

722 

723 :params session_path: Path to session data on disk 

724 :params one: ONE instance, if None is given, default ONE is instantiated 

725 :returns: Matplotlib figure 

726 """ 

727 

728 one = one or ONE() 1b

729 # hack for running on cortexlab local server 

730 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1b

731 one = ONE(base_url='https://alyx.internationalbrainlab.org') 

732 data = {} 1b

733 session_path = Path(session_path) 1b

734 

735 # Load data for each camera 

736 for cam in cameras: 1b

737 # Load a single frame for each video 

738 # Check if video data is available locally,if yes, load a single frame 

739 video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4') 1b

740 if video_path.exists(): 1b

741 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 

742 # If not, try to stream a frame (try three times) 

743 else: 

744 try: 1b

745 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1b

746 for tries in range(3): 

747 try: 

748 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 

749 break 

750 except Exception: 

751 if tries < 2: 

752 tries += 1 

753 logger.info(f"Streaming {cam} video failed, retrying x{tries}") 

754 time.sleep(30) 

755 else: 

756 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 

757 data[f'{cam}_frame'] = None 

758 except KeyError: 1b

759 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1b

760 data[f'{cam}_frame'] = None 1b

761 # Other camera associated data 

762 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1b

763 # Check locally first, then try to load from alyx, if nothing works, set to None 

764 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1b

765 continue 1b

766 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1b

767 if len(local_file) > 0: 1b

768 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0]) 

769 else: 

770 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1b

771 if len(alyx_ds) > 0: 1b

772 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0]) 

773 else: 

774 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1b

775 data[f'{cam}_{feat}'] = None 1b

776 # Sometimes there is a file but the object is empty, set to None 

777 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1b

778 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.") 

779 data[f'{cam}_{feat}'] = None 

780 

781 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong 

782 assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting." 1b

783 assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting." 

784 assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting." 

785 

786 # Load session level data 

787 for alf_object, collection in zip(['trials', 'wheel', 'licks'], [trials_collection, trials_collection, 'alf']): 

788 try: 

789 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(collection), alf_object) # load locally 

790 continue 

791 except ALFObjectNotFound: 

792 pass 

793 try: 

794 # then try from alyx 

795 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=collection) 

796 except ALFObjectNotFound: 

797 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.") 

798 data[f'{alf_object}'] = None 

799 

800 # Simplify and clean up trials data 

801 if data['trials']: 

802 data['trials'] = pd.DataFrame( 

803 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']}) 

804 # Discard nan events and too long trials 

805 data['trials'] = data['trials'].dropna() 

806 data['trials'] = data['trials'].drop( 

807 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index) 

808 

809 # Make a list of panels, if inputs are missing, instead input a text to display 

810 panels = [] 

811 # Panel A, B, C: Trace on frame 

812 for cam in cameras: 

813 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None: 

814 panels.append((plot_trace_on_frame, 

815 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam})) 

816 else: 

817 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame')) 

818 

819 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels 

820 if data['trials'] is None: 

821 panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7) 

822 else: 

823 # Panel D: Motion energy 

824 camera_dict = {} 

825 for cam in cameras: # Remove cameras where we don't have motion energy AND camera times 

826 d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')} 

827 if not any(x is None for x in d.values()): 

828 camera_dict[cam] = d 

829 if len(camera_dict) > 0: 

830 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']})) 

831 else: 

832 panels.append((None, 'Data missing\nMotion energy')) 

833 

834 # Panel E: Wheel position 

835 if data['wheel']: 

836 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position, 

837 'wheel_time': data['wheel'].timestamps, 

838 'trials_df': data['trials']})) 

839 else: 

840 panels.append((None, 'Data missing\nWheel position')) 

841 

842 # Panel F, G: Paw speed and nose speed 

843 # Try if all data is there for left cam first, otherwise right 

844 for cam in ['left', 'right']: 

845 fail = False 

846 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None 

847 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])): 

848 break 

849 fail = True 

850 if not fail: 

851 paw = 'r' if cam == 'left' else 'l' 

852 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 

853 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam})) 

854 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 

855 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False, 

856 'cam': cam})) 

857 else: 

858 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2) 

859 

860 # Panel H and I: Lick plots 

861 if data['licks'] and data['licks'].times.shape[0] > 0: 

862 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 

863 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 

864 else: 

865 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)]) 

866 

867 # Panel J: pupil plot 

868 # Try if all data is there for left cam first, otherwise right 

869 for cam in ['left', 'right']: 

870 fail = False 

871 if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None 

872 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features']) 

873 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))): 

874 break 

875 fail = True 

876 if not fail: 

877 panels.append((plot_pupil_diameter_hist, 

878 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth, 

879 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam})) 

880 else: 

881 panels.append((None, 'Data missing or corrupt\nPupil diameter')) 

882 

883 # Plotting 

884 plt.rcParams.update({'font.size': 10}) 

885 fig = plt.figure(figsize=(17, 10)) 

886 for i, panel in enumerate(panels): 

887 ax = plt.subplot(2, 5, i + 1) 

888 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold') 

889 # Check if there was in issue with inputs, if yes, print the respective text 

890 if panel[0] is None: 

891 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center', 

892 verticalalignment='center', transform=ax.transAxes) 

893 plt.axis('off') 

894 else: 

895 try: 

896 panel[0](**panel[1]) 

897 except Exception: 

898 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc()) 

899 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold', 

900 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes) 

901 plt.axis('off') 

902 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 

903 

904 return fig