Coverage for ibllib/plots/figures.py: 15%

517 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2Module that produces figures, usually for the extraction pipeline 

3""" 

4import logging 

5import time 

6from pathlib import Path 

7import traceback 

8from string import ascii_uppercase 

9 

10import numpy as np 

11import pandas as pd 

12import scipy.signal 

13import matplotlib.pyplot as plt 

14 

15from ibldsp import voltage 

16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot 

17from one.api import ONE 

18import one.alf.io as alfio 

19from one.alf.exceptions import ALFObjectNotFound 

20from ibllib.io.video import get_video_frame, url_from_eid 

21import spikeglx 

22import neuropixel 

23from brainbox.plot import driftmap 

24from brainbox.io.spikeglx import Streamer 

25from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \ 

26 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist 

27from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions 

28from brainbox.io.one import load_spike_sorting_fast 

29from brainbox.behavior import training 

30from iblutil.numerical import ismember 

31from ibllib.plots.misc import Density 

32 

33 

34logger = logging.getLogger(__name__) 

35 

36 

def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Normalise the font sizes of the axis labels, tick labels and title of a matplotlib axis.

    :param ax: matplotlib Axes modified in place
    :param labels: font size applied to the x and y axis labels (and to the colorbar label when `cmap` is True)
    :param ticklabels: font size applied to the tick labels (and to the colorbar tick labels when `cmap` is True)
    :param title: font size applied to the axis title
    :param cmap: if True, also resize the colorbar attached to the last image drawn on `ax`
    :return: None
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if not cmap:
        return

    # The colorbar lives on its own axis, attached to the most recently drawn image
    colorbar_ax = ax.images[-1].colorbar.ax
    colorbar_ax.tick_params(labelsize=ticklabels)
    colorbar_ax.yaxis.get_label().set_fontsize(labels)

57 

58 

def remove_axis_outline(ax):
    """
    Hide the x/y axes and all four spines so that an unused axis renders as blank space.

    :param ax: matplotlib Axes modified in place
    :return: None
    """
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

71 

72 

class BehaviourPlots(ReportSnapshot):
    """Behavioural plots: psychometric curve, chronometric curve and reaction time over trials."""

    @property
    def _snapshot_collection(self):
        """
        Relative collection (posix string) where the behaviour snapshots are written.

        It mirrors the trials collection without its leading 'alf' part, e.g. 'alf' -> 'snapshot/behaviour'
        and 'alf/task_00' -> 'snapshot/behaviour/task_00'. Used for both the output directory and the output
        signature so that the two always agree.
        """
        subfolder = self.trials_collection.removeprefix('alf').strip('/')
        # Path drops empty components, so an 'alf' collection maps to 'snapshot/behaviour'
        return Path('snapshot', 'behaviour', subfolder).as_posix()

    @property
    def signature(self):
        snapshot_collection = self._snapshot_collection
        return {
            'input_files': [
                ('*trials.table.pqt', self.trials_collection, True),
            ],
            'output_files': [
                ('psychometric_curve.png', snapshot_collection, True),
                ('chronometric_curve.png', snapshot_collection, True),
                ('reaction_time_with_trials.png', snapshot_collection, True),
            ],
        }

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        Generate and upload behaviour plots.

        Parameters
        ----------
        eid : str, uuid.UUID
            An experiment UUID.
        session_path : pathlib.Path, str
            A session path. If not provided, it is resolved from `eid` with `one.eid2path`.
        one : one.api.One
            An instance of ONE for registration to Alyx.
        task_collection : str
            The location of the trials data (default: 'alf'). Stored as `self.trials_collection`.
        kwargs
            Arguments for ReportSnapshot constructor.
        """
        self.one = one
        self.eid = eid
        # Accept str paths as well: the methods below rely on pathlib.Path methods (joinpath, parts)
        self.session_path = Path(session_path) if session_path else self.one.eid2path(self.eid)
        # Must be set before the base constructor, which may read the `signature` property
        self.trials_collection = kwargs.pop('task_collection', 'alf')
        super(BehaviourPlots, self).__init__(self.session_path, self.eid, one=self.one,
                                             **kwargs)
        # Output directory should mirror trials collection, sans 'alf' part
        self.output_directory = self.session_path.joinpath(self._snapshot_collection)
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _run(self):
        """
        Load the trials object and save the three behaviour figures to the output directory.

        :return: list of the saved png paths, in the order psychometric, chronometric, reaction time over trials
        """
        trials = alfio.load_object(self.session_path.joinpath(self.trials_collection), 'trials')
        if self.one:
            title = self.one.path2ref(self.session_path, as_dict=False)
        else:
            title = '_'.join(self.session_path.parts[-3:])

        plotters = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )
        output_files = []
        for plot_function, filename in plotters:
            fig, ax = plot_function(trials, title=title, figsize=(8, 6))
            set_axis_label_size(ax)
            save_path = Path(self.output_directory).joinpath(filename)
            output_files.append(save_path)
            try:
                fig.savefig(save_path)
            finally:
                # Always release the figure, even if saving fails
                plt.close(fig)

        return output_files

149 

150 

151# TODO put into histology and alignment pipeline 

class HistologySlices(ReportSnapshotProbe):
    """Plot coronal and sagittal atlas slices with the electrode sites of the probe overlaid."""

    def _run(self):
        """
        Draw the slices when the histology status allows it.

        :return: list containing the saved 'histology_slices.png' path, or an empty list when the
            histology lookup value for the current status is not positive
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        sites = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # Nothing to draw unless the histology status maps to a positive lookup value
        if self.hist_lookup[self.histology_status] <= 0:
            return []

        xyz = sites['mlapdv']
        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        # Top panel: slice along axis 1, electrodes projected on ML (column 0) vs DV (column 2)
        top_ax = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 1, ax=top_ax)
        top_ax.scatter(xyz[:, 0] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')
        top_ax.set_title(f"{self.pid_label}")

        # Bottom panel: slice along axis 0, electrodes projected on AP (column 1) vs DV (column 2)
        bottom_ax = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 0, ax=bottom_ax)
        bottom_ax.scatter(xyz[:, 1] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')

        # Right strip spanning both rows: brain regions traversed by the electrode sites
        regions_ax = fig.add_subplot(grid[:, 1])
        plot_brain_regions(sites['atlas_id'], brain_regions=self.brain_regions, display=True, ax=regions_ax,
                           title=self.histology_status)

        out_file = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(out_file)
        plt.close(fig)

        return [out_file]

    def get_probe_signature(self):
        """Set `self.signature`: optional electrode-site datasets in, one histology snapshot out."""
        alf_dir = f'alf/{self.pname}'
        site_attributes = ('localCoordinates', 'brainLocationIds_ccf_2017', 'mlapdv')
        self.signature = {
            'input_files': [(f'electrodeSites.{attribute}.npy', alf_dir, False) for attribute in site_attributes],
            'output_files': [('histology_slices.png', f'snapshot/{self.pname}', True)],
        }

193 

194 

class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS plots
    """

    def _run(self):
        """
        Plot the LFP spectral density and the LFP RMS over time for the probe.

        Off-server, the histology status and electrode sites are fetched so that a brain-region strip
        can be drawn next to each image; otherwise the side panel is left blank.

        :return: list of the saved figure paths (lfp_spectrum.png, lfp_rms.png)
        """
        assert self.pid

        # Initialised here so the server path never reaches an unbound name when drawing brain regions
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        raw_collection = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')
        output_files = []

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(raw_collection, 'ephysSpectralDensityLF', namespace='iblqc')
        image_lfp_spectrum_plot(lfp.power, lfp.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)}, ax=axs[0],
                                display=True, title=f"{self.pid_label}")
        output_files.append(self._finalise_figure(fig, axs, electrodes, "lfp_spectrum.png"))

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(raw_collection, 'ephysTimeRmsLF', namespace='iblqc')
        image_rms_plot(lfp.rms, lfp.timestamps, median_subtract=False, band='LFP', clim=[-35, -45], ax=axs[0],
                       cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        output_files.append(self._finalise_figure(fig, axs, electrodes, "lfp_rms.png"))

        return output_files

    def _finalise_figure(self, fig, axs, electrodes, filename):
        """
        Format a two-panel (image, brain regions) figure, save it to the output directory and close it.

        :param fig: matplotlib figure to save
        :param axs: pair of axes; axs[0] holds the image with its colorbar, axs[1] the brain-region strip
        :param electrodes: electrode sites with an 'atlas_id' entry, or None when unavailable
        :param filename: name of the png file written in `self.output_directory`
        :return: pathlib.Path of the saved figure
        """
        set_axis_label_size(axs[0], cmap=True)
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath(filename)
        try:
            fig.savefig(save_path)
        finally:
            plt.close(fig)
        return save_path

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
                            ('lfp_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

261 

262 

class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS plots
    """

    def _run(self):
        """
        Plot the AP-band RMS over time for the probe, with a brain-region strip when histology is available.

        :return: list containing the saved 'ap_rms.png' path
        """
        assert self.pid

        output_files = []

        # Initialised here so the server path never reaches an unbound name when drawing brain regions
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP',
                               namespace='iblqc')
        image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
                       fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath("ap_rms.png")
        output_files.append(save_path)
        try:
            fig.savefig(save_path)
        finally:
            plt.close(fig)

        return output_files

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

306 

307 

class SpikeSorting(ReportSnapshotProbe):
    """
    Plots a spike raster (driftmap) for each spike sorting run of a probe insertion.
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Plot and save the driftmap of one spike sorting run, with a brain-region strip when histology is available.

        :param spikes: spikes object exposing `times`, `depths` and `clusters`
        :param clusters: clusters object exposing `depths`
        :param channels: channels mapping with 'axial_um' (and 'atlas_id' when histology is available)
        :param collection: ALF collection of the run, 'alf/<pname>' or a sub-collection of it
        :return: (output file path, figure, axes); the figure is already closed
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # The run label is the sub-collection below 'alf/<pname>'; the top-level collection is labelled ks2matlab
        run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        if self.histology_status:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)

        return outfile, fig, axs

    def _run(self, collection=None):
        """
        Create the driftmap snapshot(s).

        On the server, only the run stored in `collection` (required) is plotted from local files. Elsewhere,
        every spike sorting run registered on Alyx for the probe is plotted, unless one snapshot per run
        already exists, in which case the existing snapshots are returned untouched.

        :param collection: ALF collection of the spike sorting run, required when running on the server
        :return: list of snapshot file paths, each listed once
        """
        output_files = []
        if self.location == 'server':
            assert collection
            spikes = alfio.load_object(self.session_path.joinpath(collection), 'spikes')
            clusters = alfio.load_object(self.session_path.joinpath(collection), 'clusters')
            channels = alfio.load_object(self.session_path.joinpath(collection), 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]

            out, _, _ = self._plot_driftmap(spikes, clusters, channels, collection)
            output_files.append(out)
            return output_files

        self.histology_status = self.get_histology_status()
        all_here, existing_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
        if all_here and len(existing_files) == len(spike_sorting_runs):
            return existing_files
        logger.info(self.output_directory)
        # Every run is regenerated; collecting into a fresh list avoids listing pre-existing snapshots twice
        for run in spike_sorting_runs:
            run_collection = str(Path(run).parent.as_posix())
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=run_collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', run_collection)

            out, _, _ = self._plot_driftmap(spikes, clusters, channels, run_collection)
            output_files.append(out)

        return output_files

    def get_probe_signature(self):
        input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
                           ('spikes.amps.npy', f'alf/{self.pname}*', True),
                           ('spikes.depths.npy', f'alf/{self.pname}*', True),
                           ('clusters.depths.npy', f'alf/{self.pname}*', True),
                           ('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
                           ('channels.mlapdv.npy', f'alf/{self.pname}*', False),
                           ('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """Expand the input signature to every local folder under 'alf' that contains spikes.times.npy."""
        files_spikes = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')
        folder_probes = [f.parent for f in files_spikes]

        full_input_files = []
        for sig in self.signature['input_files']:
            for folder in folder_probes:
                full_input_files.append((sig[0], str(folder.relative_to(self.session_path)), sig[2]))
        if len(full_input_files) != 0:
            self.input_files = full_input_files
        else:
            self.input_files = self.signature['input_files']

        self.output_files = self.signature['output_files']

401 

402 

class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots raw electrophysiology AP band, highlighting the channels detected as dead, noisy or outside the brain.
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        # One features figure plus one trace figure per processing stage (highpass, destripe, difference):
        # four files in total, matching the completeness check in `_run`
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """runs for initiated PID, streams data, destripe and check bad channels"""
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            nsecs = 1
            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If T0 is greater than recording length, take 500 sec before end; clamp at 0 so recordings
            # shorter than 500 s are read from their start rather than with a negative (wrapping) index
            if sr.rl < T0:
                T0 = max(int(sr.rl - 500), 0)
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files

479 

480 

def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Plot voltage traces and bad-channel detection features for a chunk of electrophysiology data.

    :param raw: voltage array unpacked as (nc, ns), i.e. channels x samples; presumably in Volts since the
        features panel scales it by 1e6 to uV - confirm with callers
    :param fs: sampling frequency (Hz); fs >= 2600 selects the AP-band display settings, otherwise LFP settings
    :param channel_labels: per-channel labels, 1 = dead, 2 = noisy, 3 = outside
    :param channel_features: mapping with 'rms_raw', 'psd_hf', 'xcor_hf' and 'xcor_lf' per-channel features
    :param h: trace header passed to `voltage.destripe`
    :param channels: channel information with 'atlas_id' (and optional 'title') to draw a brain-region strip;
        None leaves the strip blank (matplotlib backend)
    :param title: title of the features figure and prefix of the saved file names
    :param save_dir: if given, directory where the features figure and one figure per trace view are saved
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: optional list of existing trace viewers; new viewers are appended to it
    :param br: brain regions object passed to the brain-region plots
    :param pid_info: label prefixed to the trace figure titles (matplotlib backend)
    :param plot_backend: 'matplotlib' to use `Density` figures, anything else uses the viewspikes GUI
    :return: (features figure, eqcs, output_files) when `save_dir` is given, otherwise (features figure, eqcs)
    """
    nc, ns = raw.shape
    rl = ns / fs

    def gain2level(gain):
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        eqc_xrange = [450, 500]
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        eqc_xrange = [450, 950]
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
    eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]
    # Time positions (ms) along which bad-channel markers are drawn over the traces
    marker_times = np.linspace(0, rl * 1e3, 500)

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, marker_times)
            eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
            y, x = np.meshgrid(inoisy, marker_times)
            eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
            y, x = np.meshgrid(idead, marker_times)
            eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            # Second axis of the figure holds the brain-region strip
            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa
        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, marker_times)
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
            y, x = np.meshgrid(inoisy, marker_times)
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
            y, x = np.meshgrid(idead, marker_times)
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    # Clip noisy markers just below the top of the y-range so they stay visible for both AP and LFP limits
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], ylim_psd_hf[1] * 0.999), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is not None:
        output_files = [Path(save_dir).joinpath(f"{title}.png")]
        fig.savefig(output_files[0])
        for eqc in eqcs:
            if plot_backend == 'matplotlib':
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
                eqc.figure.savefig(str(output_files[-1]))
            else:
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
                eqc.grab().save(str(output_files[-1]))
        return fig, eqcs, output_files
    else:
        return fig, eqcs

613 

614 

def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a chunk of raw ephys data and display it as a density image in one subplot of a figure.

    :param raw: raw ephys data, Nc x Ns (channels x samples); displayed with time (ms) on x and channel on y
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start
    :param i_plt: increment of plot to display image one (start from 0, has to be < n_plt)
    :param n_plt: total number of subplot on figure
    :param fig: figure handle
    :param axs: axis handle (1D sequence of axes); a new figure is created when either fig or axs is None
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; destriping is skipped when raw has another count
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if i_plt does not index one of the available subplots
    '''
    # Init fig: one equally-sized column per subplot. width_ratios must be a sequence of length n_plt, and
    # squeeze=False keeps axs a 1D array of axes even when n_plt == 1
    if fig is None or axs is None:
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), gridspec_kw={'width_ratios': [4] * n_plt},
                                squeeze=False)
        axs = axs[0]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt+1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    nc, ns = raw.shape
    if nc == N_CHAN:
        destriped = voltage.destripe(raw, fs=fs)
        X = destriped[:, :int(DISPLAY_TIME * fs)].T
        Xs = X[SAMPLE_SKIP:].T  # Remove artifact at beginning
        Tplot = Xs.shape[1] / fs

        # PLOT RAW DATA
        Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X)
        axs[i_plt].set_ylabel('')
        axs[i_plt].set_xlim((0, Tplot * 1e3))
        axs[i_plt].set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect and remove bad channels prior to spike detection
            labels, xfeats = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data
            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        axs[i_plt].title.set_text(title_plt)

    else:
        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        axs[i_plt].set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs

691 

692 

693def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data', 

694 cameras=('left', 'right', 'body'), trials_collection='alf'): 

695 """ 

696 Creates DLC QC plot. 

697 Data is searched first locally, then on Alyx. Panels that lack required data are skipped. 

698 

699 Required data to create all panels 

700 'raw_video_data/_iblrig_bodyCamera.raw.mp4', 

701 'raw_video_data/_iblrig_leftCamera.raw.mp4', 

702 'raw_video_data/_iblrig_rightCamera.raw.mp4', 

703 'alf/_ibl_bodyCamera.dlc.pqt', 

704 'alf/_ibl_leftCamera.dlc.pqt', 

705 'alf/_ibl_rightCamera.dlc.pqt', 

706 'alf/_ibl_bodyCamera.times.npy', 

707 'alf/_ibl_leftCamera.times.npy', 

708 'alf/_ibl_rightCamera.times.npy', 

709 'alf/_ibl_leftCamera.features.pqt', 

710 'alf/_ibl_rightCamera.features.pqt', 

711 'alf/rightROIMotionEnergy.position.npy', 

712 'alf/leftROIMotionEnergy.position.npy', 

713 'alf/bodyROIMotionEnergy.position.npy', 

714 'alf/_ibl_trials.choice.npy', 

715 'alf/_ibl_trials.feedbackType.npy', 

716 'alf/_ibl_trials.feedback_times.npy', 

717 'alf/_ibl_trials.stimOn_times.npy', 

718 'alf/_ibl_wheel.position.npy', 

719 'alf/_ibl_wheel.timestamps.npy', 

720 'alf/licks.times.npy', 

721 

722 :params session_path: Path to session data on disk 

723 :params one: ONE instance, if None is given, default ONE is instantiated 

724 :returns: Matplotlib figure 

725 """ 

726 

727 one = one or ONE() 1b

728 # hack for running on cortexlab local server 

729 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1b

730 one = ONE(base_url='https://alyx.internationalbrainlab.org') 

731 data = {} 1b

732 session_path = Path(session_path) 1b

733 

734 # Load data for each camera 

735 for cam in cameras: 1b

736 # Load a single frame for each video 

737 # Check if video data is available locally,if yes, load a single frame 

738 video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4') 1b

739 if video_path.exists(): 1b

740 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 

741 # If not, try to stream a frame (try three times) 

742 else: 

743 try: 1b

744 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1b

745 for tries in range(3): 

746 try: 

747 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 

748 break 

749 except Exception: 

750 if tries < 2: 

751 tries += 1 

752 logger.info(f"Streaming {cam} video failed, retrying x{tries}") 

753 time.sleep(30) 

754 else: 

755 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 

756 data[f'{cam}_frame'] = None 

757 except KeyError: 1b

758 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1b

759 data[f'{cam}_frame'] = None 1b

760 # Other camera associated data 

761 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1b

762 # Check locally first, then try to load from alyx, if nothing works, set to None 

763 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1b

764 continue 1b

765 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1b

766 if len(local_file) > 0: 1b

767 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0]) 

768 else: 

769 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1b

770 if len(alyx_ds) > 0: 1b

771 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0]) 

772 else: 

773 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1b

774 data[f'{cam}_{feat}'] = None 1b

775 # Sometimes there is a file but the object is empty, set to None 

776 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1b

777 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.") 

778 data[f'{cam}_{feat}'] = None 

779 

780 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong 

781 assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting." 1b

782 assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting." 

783 assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting." 

784 

785 # Load session level data 

786 for alf_object, collection in zip(['trials', 'wheel', 'licks'], [trials_collection, trials_collection, 'alf']): 

787 try: 

788 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(collection), alf_object) # load locally 

789 continue 

790 except ALFObjectNotFound: 

791 pass 

792 try: 

793 # then try from alyx 

794 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=collection) 

795 except ALFObjectNotFound: 

796 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.") 

797 data[f'{alf_object}'] = None 

798 

799 # Simplify and clean up trials data 

800 if data['trials']: 

801 data['trials'] = pd.DataFrame( 

802 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']}) 

803 # Discard nan events and too long trials 

804 data['trials'] = data['trials'].dropna() 

805 data['trials'] = data['trials'].drop( 

806 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index) 

807 

808 # Make a list of panels, if inputs are missing, instead input a text to display 

809 panels = [] 

810 # Panel A, B, C: Trace on frame 

811 for cam in cameras: 

812 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None: 

813 panels.append((plot_trace_on_frame, 

814 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam})) 

815 else: 

816 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame')) 

817 

818 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels 

819 if data['trials'] is None: 

820 panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7) 

821 else: 

822 # Panel D: Motion energy 

823 camera_dict = {} 

824 for cam in cameras: # Remove cameras where we don't have motion energy AND camera times 

825 d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')} 

826 if not any(x is None for x in d.values()): 

827 camera_dict[cam] = d 

828 if len(camera_dict) > 0: 

829 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']})) 

830 else: 

831 panels.append((None, 'Data missing\nMotion energy')) 

832 

833 # Panel E: Wheel position 

834 if data['wheel']: 

835 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position, 

836 'wheel_time': data['wheel'].timestamps, 

837 'trials_df': data['trials']})) 

838 else: 

839 panels.append((None, 'Data missing\nWheel position')) 

840 

841 # Panel F, G: Paw speed and nose speed 

842 # Try if all data is there for left cam first, otherwise right 

843 for cam in ['left', 'right']: 

844 fail = False 

845 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None 

846 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])): 

847 break 

848 fail = True 

849 if not fail: 

850 paw = 'r' if cam == 'left' else 'l' 

851 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 

852 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam})) 

853 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 

854 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False, 

855 'cam': cam})) 

856 else: 

857 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2) 

858 

859 # Panel H and I: Lick plots 

860 if data['licks'] and data['licks'].times.shape[0] > 0: 

861 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 

862 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 

863 else: 

864 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)]) 

865 

866 # Panel J: pupil plot 

867 # Try if all data is there for left cam first, otherwise right 

868 for cam in ['left', 'right']: 

869 fail = False 

870 if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None 

871 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features']) 

872 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))): 

873 break 

874 fail = True 

875 if not fail: 

876 panels.append((plot_pupil_diameter_hist, 

877 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth, 

878 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam})) 

879 else: 

880 panels.append((None, 'Data missing or corrupt\nPupil diameter')) 

881 

882 # Plotting 

883 plt.rcParams.update({'font.size': 10}) 

884 fig = plt.figure(figsize=(17, 10)) 

885 for i, panel in enumerate(panels): 

886 ax = plt.subplot(2, 5, i + 1) 

887 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold') 

888 # Check if there was in issue with inputs, if yes, print the respective text 

889 if panel[0] is None: 

890 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center', 

891 verticalalignment='center', transform=ax.transAxes) 

892 plt.axis('off') 

893 else: 

894 try: 

895 panel[0](**panel[1]) 

896 except Exception: 

897 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc()) 

898 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold', 

899 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes) 

900 plt.axis('off') 

901 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 

902 

903 return fig