Coverage for ibllib/plots/figures.py: 68%

511 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1""" 

2Module that produces figures, usually for the extraction pipeline 

3""" 

4import logging 

5import time 

6from pathlib import Path 

7import traceback 

8from string import ascii_uppercase 

9 

10import numpy as np 

11import pandas as pd 

12import scipy.signal 

13import matplotlib.pyplot as plt 

14 

15from neurodsp import voltage 

16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot 

17from one.api import ONE 

18import one.alf.io as alfio 

19from one.alf.exceptions import ALFObjectNotFound 

20from ibllib.io.video import get_video_frame, url_from_eid 

21import spikeglx 

22import neuropixel 

23from brainbox.plot import driftmap 

24from brainbox.io.spikeglx import Streamer 

25from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \ 

26 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist 

27from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions 

28from brainbox.io.one import load_spike_sorting_fast 

29from brainbox.behavior import training 

30from iblutil.numerical import ismember 

31from ibllib.plots.misc import Density 

32 

33 

34logger = logging.getLogger(__name__) 

35 

36 

def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Normalise the font sizes of an axis' labels, tick labels and title, in place.

    :param ax: matplotlib Axes to restyle
    :param labels: font size of the x and y axis labels (and of the colorbar label when `cmap` is True)
    :param ticklabels: font size of the tick labels (and of the colorbar tick labels when `cmap` is True)
    :param title: font size of the axis title
    :param cmap: if True, also restyle the colorbar attached to the last image drawn on `ax`
    :return: None
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if not cmap:
        return

    # the colorbar of interest is the one attached to the most recently added image
    colorbar_ax = ax.images[-1].colorbar.ax
    colorbar_ax.tick_params(labelsize=ticklabels)
    colorbar_ax.yaxis.get_label().set_fontsize(labels)

57 

58 

def remove_axis_outline(ax):
    """
    Hide both axes and all four spines so that an unused axis renders as blank space.

    :param ax: matplotlib Axes to blank out, modified in place
    :return: None
    """
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

71 

72 

class BehaviourPlots(ReportSnapshot):
    """
    Behavioural snapshot plots: psychometric curve, chronometric curve and reaction time over trials.
    The figures are saved as PNG files in the `snapshot/behaviour` folder of the session.
    """

    signature = {
        'input_files': [
            ('*trials.table.pqt', 'alf', True),
        ],
        'output_files': [
            ('psychometric_curve.png', 'snapshot/behaviour', True),
            ('chronometric_curve.png', 'snapshot/behaviour', True),
            ('reaction_time_with_trials.png', 'snapshot/behaviour', True)
        ]
    }

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        :param eid: experiment (session) ID
        :param session_path: session folder; resolved from `eid` through `one` when not given
        :param one: ONE instance
        :param kwargs: extra keyword arguments forwarded to ReportSnapshot
        """
        self.one = one
        self.eid = eid
        if not session_path:
            session_path = self.one.eid2path(self.eid)
        self.session_path = session_path
        super(BehaviourPlots, self).__init__(self.session_path, self.eid, one=self.one, **kwargs)
        self.output_directory = self.session_path.joinpath('snapshot', 'behaviour')
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _run(self):
        """
        Load the session trials and render the three behaviour figures.

        :return: list of paths to the saved PNG files, in the order of the output signature
        """
        trials = alfio.load_object(self.session_path.joinpath('alf'), 'trials')
        title = '_'.join(self.session_path.parts[-3:])

        # (plotting function, output file name), rendered in this order
        figures = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )

        saved = []
        for plot_function, file_name in figures:
            fig, ax = plot_function(trials, title=title, figsize=(8, 6))
            set_axis_label_size(ax)
            png_path = Path(self.output_directory).joinpath(file_name)
            saved.append(png_path)
            fig.savefig(png_path)
            plt.close(fig)

        return saved

126 

127 

128# TODO put into histology and alignment pipeline 

class HistologySlices(ReportSnapshotProbe):
    """
    Plots coronal and sagittal slice showing electrode locations, alongside the brain regions
    crossed by the electrode sites.
    """

    def _run(self):
        """
        Render the histology slices figure when the histology status warrants it.

        :return: list holding the path of `histology_slices.png`, or an empty list when the histology
          lookup value of the current status is not positive
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # only statuses mapped to a positive lookup value get a figure
        if not (self.hist_lookup[self.histology_status] > 0):
            return []

        xyz = electrodes['mlapdv']
        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        # top slice (atlas axis 1): electrodes drawn with mlapdv columns 0 and 2
        ax_top = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 1, ax=ax_top)
        ax_top.scatter(xyz[:, 0] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')
        ax_top.set_title(f"{self.pid_label}")

        # bottom slice (atlas axis 0): electrodes drawn with mlapdv columns 1 and 2
        ax_bottom = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 0, ax=ax_bottom)
        ax_bottom.scatter(xyz[:, 1] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')

        # narrow right-hand column spanning both rows: brain regions along the probe
        ax_regions = fig.add_subplot(grid[:, 1])
        plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)

        png_path = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(png_path)
        plt.close(fig)
        return [png_path]

    def get_probe_signature(self):
        """Set the input (electrode site coordinates and atlas ids) and output signature for this probe."""
        alf_collection = f'alf/{self.pname}'
        input_signature = [('electrodeSites.localCoordinates.npy', alf_collection, False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', alf_collection, False),
                           ('electrodeSites.mlapdv.npy', alf_collection, False)]
        output_signature = [('histology_slices.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

172 

173 

class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS plots, each next to the brain regions crossed by the probe when
    histology is available.
    """

    def _draw_regions_panel(self, ax, electrodes):
        """
        Draw the brain regions on the narrow side panel, or blank it out when they cannot be drawn.

        :param ax: matplotlib Axes of the side panel
        :param electrodes: electrode sites dict with an 'atlas_id' key, or None when not loaded
        """
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax,
                               title=self.histology_status)
            set_axis_label_size(ax)
        else:
            remove_axis_outline(ax)

    def _run(self):
        """
        Render the LFP spectral density and LFP RMS figures.

        :return: list of paths to `lfp_spectrum.png` and `lfp_rms.png`
        """
        assert self.pid

        output_files = []
        qc_path = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')

        # electrode sites are only loaded off-server; when they are missing the regions panel is
        # blanked instead of referencing an undefined variable
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(qc_path, 'ephysSpectralDensityLF', namespace='iblqc')
        _, _, _ = image_lfp_spectrum_plot(lfp.power, lfp.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)},
                                          ax=axs[0], display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._draw_regions_panel(axs[1], electrodes)

        save_path = Path(self.output_directory).joinpath("lfp_spectrum.png")
        output_files.append(save_path)
        fig.savefig(save_path)
        plt.close(fig)

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(qc_path, 'ephysTimeRmsLF', namespace='iblqc')
        _, _, _ = image_rms_plot(lfp.rms, lfp.timestamps, median_subtract=False, band='LFP', clim=[-35, -45],
                                 ax=axs[0], cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True,
                                 title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._draw_regions_panel(axs[1], electrodes)

        save_path = Path(self.output_directory).joinpath("lfp_rms.png")
        output_files.append(save_path)
        fig.savefig(save_path)
        plt.close(fig)

        return output_files

    def get_probe_signature(self):
        """Set the input (LFP QC metrics, electrode sites) and output signature for this probe."""
        input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
                            ('lfp_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

240 

241 

class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS plots, next to the brain regions crossed by the probe when histology is available.
    """

    def _run(self):
        """
        Render the AP band RMS figure.

        :return: list holding the path of `ap_rms.png`
        """
        assert self.pid

        output_files = []

        # electrode sites are only loaded off-server; when they are missing the regions panel is
        # blanked instead of referencing an undefined variable
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP',
                               namespace='iblqc')
        _, _, _ = image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
                                 fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath("ap_rms.png")
        output_files.append(save_path)
        fig.savefig(save_path)
        plt.close(fig)

        return output_files

    def get_probe_signature(self):
        """Set the input (AP RMS QC metrics, electrode sites) and output signature for this probe."""
        input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

285 

286 

class SpikeSorting(ReportSnapshotProbe):
    """
    Plots the spike sorting driftmap (spike depth over time) of a probe insertion, one figure per
    spike sorting run, next to the brain regions crossed by the probe when histology is available.
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Plot and save the driftmap of a single spike sorting run.

        :param spikes: spikes object exposing `times`, `depths` and `clusters`
        :param clusters: clusters object exposing `depths`
        :param channels: channels dict with an 'axial_um' key and, for the regions panel, 'atlas_id'
        :param collection: ALF collection of the run, located below `alf/<pname>`
        :return: (path of the saved PNG, figure, axes)
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # the run is labelled by its sub-folder below alf/<pname>; the probe folder itself is 'ks2matlab'
        run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        # channels loaded straight from disk may lack atlas ids: blank the panel rather than crash
        if self.histology_status and 'atlas_id' in channels:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)

        return outfile, fig, axs

    def _run(self, collection=None):
        """
        Create the driftmap snapshot(s) of the probe insertion.

        On the server a single run is plotted from the local `collection`. Elsewhere every spike sorting
        run of the probe listed on Alyx is plotted, unless one snapshot per run already exists.

        :param collection: ALF collection to plot; required when running on the server
        :return: list of paths to the driftmap PNG files
        """
        if self.location == 'server':
            assert collection
            alf_path = self.session_path.joinpath(collection)
            spikes = alfio.load_object(alf_path, 'spikes')
            clusters = alfio.load_object(alf_path, 'clusters')
            channels = alfio.load_object(alf_path, 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]

            out, fig, axs = self._plot_driftmap(spikes, clusters, channels, collection)
            return [out]

        self.histology_status = self.get_histology_status()
        all_here, existing_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy',
                                                    collection=f'alf/{self.pname}*')
        if all_here and len(existing_files) == len(spike_sorting_runs):
            return existing_files
        logger.info(self.output_directory)

        # every run is re-plotted below, so start from an empty list to avoid duplicating the files
        # that assert_expected already found
        output_files = []
        for run in spike_sorting_runs:
            collection = Path(run).parent.as_posix()
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', collection)

            out, fig, axs = self._plot_driftmap(spikes, clusters, channels, collection)
            output_files.append(out)

        return output_files

    def get_probe_signature(self):
        """Set the input (spike sorting output of any run of this probe) and output signature."""
        input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
                           ('spikes.amps.npy', f'alf/{self.pname}*', True),
                           ('spikes.depths.npy', f'alf/{self.pname}*', True),
                           ('clusters.depths.npy', f'alf/{self.pname}*', True),
                           ('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
                           ('channels.mlapdv.npy', f'alf/{self.pname}*', False),
                           ('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """
        Expand the input signature to every spike sorting folder of this probe found on disk; fall back
        to the wildcard signature when none is found.
        """
        alf_path = Path(self.session_path).joinpath('alf')
        # restrict the search to this probe so that other probes' runs do not become inputs
        if self.pname:
            alf_path = alf_path.joinpath(self.pname)
        folder_probes = [f.parent for f in alf_path.rglob('spikes.times.npy')]

        full_input_files = [(sig[0], str(folder.relative_to(self.session_path)), sig[2])
                            for sig in self.signature['input_files'] for folder in folder_probes]
        self.input_files = full_input_files if full_input_files else self.signature['input_files']
        self.output_files = self.signature['output_files']

380 

381 

class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots one second of raw electrophysiology AP band data with the detected bad channels highlighted,
    together with the channel features used for the detection.
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        """Set the input (AP band meta/ch files) and output (four snapshot PNGs) signature of the probe."""
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        # features figure, then the highpass, destripe and difference voltage displays
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """
        Stream (off-server) or read locally (on the server) one second of AP band data, detect the bad
        channels and save the snapshots.

        :return: list of paths to the saved snapshots; an empty list when no AP binary file is found on
          the server
        """
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30  # start of the displayed snippet (s)
        nsecs = 1  # duration of the displayed snippet (s)
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        # the four snapshots of the output signature already exist
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If the snippet would run past the end of the recording, take it 500 sec before the end
            # instead, or from the start for recordings shorter than that
            if sr.rl < T0 + nsecs:
                T0 = max(int(sr.rl - 500), 0)
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + nsecs))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files

458 

459 

def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Display a raw electrophysiology snippet with its bad channels highlighted, plus a summary figure of
    the channel features used to detect them.

    :param raw: voltage snippet, unpacked as (number of channels, number of samples)
    :param fs: sampling frequency (Hz); fs >= 2600 selects the AP band display settings, lower values
     the LFP band settings
    :param channel_labels: array with one label per channel: 1 = dead, 2 = noisy, 3 = outside
    :param channel_features: dict of per channel features with keys 'rms_raw', 'psd_hf', 'xcor_hf', 'xcor_lf'
    :param h: trace header forwarded to `voltage.destripe` (only used when destripe is True)
    :param channels: optional dict with an 'atlas_id' key (and optional 'title') used to draw the brain
     regions panel; the panel is blanked when None
    :param title: suptitle of the features figure and stem of the saved file names
    :param save_dir: folder in which the figures are saved; nothing is saved when None
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: optional list to which the voltage displays are appended (a new list is used if None)
    :param br: brain regions object forwarded to the brain regions plot / viewer
    :param pid_info: label prefixed to the title of each matplotlib voltage display
    :param plot_backend: 'matplotlib' for static figures; any other value uses the viewephys GUI
    :return: (features figure, list of voltage displays), plus the list of saved paths when save_dir is given
    """
    nc, ns = raw.shape
    rl = ns / fs  # snippet duration (s)

    def gain2level(gain):
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        eqc_xrange = [450, 500]
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
        eqc_levels = gain2level(eqc_gain)
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        eqc_xrange = [450, 950]
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
        eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)
    # time axis (ms) along which the bad channel markers are drawn
    tmarkers = np.linspace(0, rl * 1e3, 500)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            for ichannels, color in ((ioutside, 'goldenrod'), (inoisy, 'r'), (idead, 'b')):
                y, x = np.meshgrid(ichannels, tmarkers)
                eqc.ax.scatter(x.flatten(), y.flatten(), c=color, s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            # narrow right-hand panel of the figure: brain regions, or blank when unavailable
            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa

        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        markers = ((ioutside, (164, 142, 35), 'outside'), (inoisy, (255, 0, 0), 'noisy'), (idead, (0, 0, 255), 'dead'))
        for eqc in eqcs:
            for ichannels, rgb, label in markers:
                y, x = np.meshgrid(ichannels, tmarkers)
                eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=rgb, label=label)

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    # clip the noisy markers just below the top of the y range so that they stay visible
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], ylim_psd_hf[1] * 0.999), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is None:
        return fig, eqcs

    output_files = [Path(save_dir).joinpath(f"{title}.png")]
    fig.savefig(output_files[0])
    for eqc in eqcs:
        if plot_backend == 'matplotlib':
            output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
            eqc.figure.savefig(str(output_files[-1]))
        else:
            output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
            eqc.grab().save(str(output_files[-1]))
    return fig, eqcs, output_files

592 

593 

def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a raw ephys snippet and display it as a density image in one subplot of a figure.

    :param raw: raw ephys data, shape (Nc, Ns): one row per channel, one column per sample
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start, shown in the subplot title
    :param i_plt: index of the subplot to draw into (start from 0, has to be < n_plt)
    :param n_plt: total number of subplots on the figure
    :param fig: figure handle; a new figure is created when fig or axs is None
    :param axs: sequence of axis handles, one per subplot
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; data with another channel count is not destriped
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if i_plt is not the index of one of the subplots
    '''
    # Init fig: squeeze=False so that a single subplot is still returned as an indexable array of axes
    if fig is None or axs is None:
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), squeeze=False,
                                gridspec_kw={'width_ratios': [4] * n_plt})
        axs = axs[0]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt+1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    nc, _ = raw.shape
    if nc == N_CHAN:
        destriped = voltage.destripe(raw, fs=fs)
        X = destriped[:, :int(DISPLAY_TIME * fs)].T
        Xs = X[SAMPLE_SKIP:].T  # Remove artifact at beginning
        Tplot = Xs.shape[1] / fs

        # PLOT RAW DATA
        Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X)
        axs[i_plt].set_ylabel('')
        axs[i_plt].set_xlim((0, Tplot * 1e3))
        axs[i_plt].set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect and remove bad channels prior to spike detection
            labels, xfeats = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data
            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        axs[i_plt].title.set_text(title_plt)

    else:
        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        axs[i_plt].set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs

670 

671 

672def dlc_qc_plot(session_path, one=None): 

673 """ 

674 Creates DLC QC plot. 

675 Data is searched first locally, then on Alyx. Panels that lack required data are skipped. 

676 

677 Required data to create all panels 

678 'raw_video_data/_iblrig_bodyCamera.raw.mp4', 

679 'raw_video_data/_iblrig_leftCamera.raw.mp4', 

680 'raw_video_data/_iblrig_rightCamera.raw.mp4', 

681 'alf/_ibl_bodyCamera.dlc.pqt', 

682 'alf/_ibl_leftCamera.dlc.pqt', 

683 'alf/_ibl_rightCamera.dlc.pqt', 

684 'alf/_ibl_bodyCamera.times.npy', 

685 'alf/_ibl_leftCamera.times.npy', 

686 'alf/_ibl_rightCamera.times.npy', 

687 'alf/_ibl_leftCamera.features.pqt', 

688 'alf/_ibl_rightCamera.features.pqt', 

689 'alf/rightROIMotionEnergy.position.npy', 

690 'alf/leftROIMotionEnergy.position.npy', 

691 'alf/bodyROIMotionEnergy.position.npy', 

692 'alf/_ibl_trials.choice.npy', 

693 'alf/_ibl_trials.feedbackType.npy', 

694 'alf/_ibl_trials.feedback_times.npy', 

695 'alf/_ibl_trials.stimOn_times.npy', 

696 'alf/_ibl_wheel.position.npy', 

697 'alf/_ibl_wheel.timestamps.npy', 

698 'alf/licks.times.npy', 

699 

700 :params session_path: Path to session data on disk 

701 :params one: ONE instance, if None is given, default ONE is instantiated 

702 :returns: Matplotlib figure 

703 """ 

704 

705 one = one or ONE() 1ca

706 # hack for running on cortexlab local server 

707 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1ca

708 one = ONE(base_url='https://alyx.internationalbrainlab.org') 

709 data = {} 1ca

710 cams = ['left', 'right', 'body'] 1ca

711 session_path = Path(session_path) 1ca

712 

713 # Load data for each camera 

714 for cam in cams: 1ca

715 # Load a single frame for each video 

716 # Check if video data is available locally,if yes, load a single frame 

717 video_path = session_path.joinpath('raw_video_data', f'_iblrig_{cam}Camera.raw.mp4') 1ca

718 if video_path.exists(): 1ca

719 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 1a

720 # If not, try to stream a frame (try three times) 

721 else: 

722 try: 1c

723 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1c

724 for tries in range(3): 

725 try: 

726 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 

727 break 

728 except BaseException: 

729 if tries < 2: 

730 tries += 1 

731 logger.info(f"Streaming {cam} video failed, retrying x{tries}") 

732 time.sleep(30) 

733 else: 

734 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 

735 data[f'{cam}_frame'] = None 

736 except KeyError: 1c

737 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1c

738 data[f'{cam}_frame'] = None 1c

739 # Other camera associated data 

740 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1ca

741 # Check locally first, then try to load from alyx, if nothing works, set to None 

742 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1ca

743 continue 1ca

744 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1ca

745 if len(local_file) > 0: 1ca

746 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0]) 1a

747 else: 

748 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1c

749 if len(alyx_ds) > 0: 1c

750 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0]) 

751 else: 

752 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1c

753 data[f'{cam}_{feat}'] = None 1c

754 # Sometimes there is a file but the object is empty, set to None 

755 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1ca

756 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.") 

757 data[f'{cam}_{feat}'] = None 

758 

759 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong 

760 assert any([data[f'{cam}_frame'] is not None for cam in cams]), "No camera data could be loaded, aborting." 1ca

761 assert any([data[f'{cam}_dlc'] is not None for cam in cams]), "No DLC data could be loaded, aborting." 1a

762 assert any([data[f'{cam}_times'] is not None for cam in cams]), "No camera times data could be loaded, aborting." 1a

763 

764 # Load session level data 

765 for alf_object in ['trials', 'wheel', 'licks']: 1a

766 try: 1a

767 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath('alf'), alf_object) # load locally 1a

768 continue 1a

769 except ALFObjectNotFound: 

770 pass 

771 try: 

772 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object) # then try from alyx 

773 except ALFObjectNotFound: 

774 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.") 

775 data[f'{alf_object}'] = None 

776 

777 # Simplify and clean up trials data 

778 if data['trials']: 1a

779 data['trials'] = pd.DataFrame( 1a

780 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']}) 

781 # Discard nan events and too long trials 

782 data['trials'] = data['trials'].dropna() 1a

783 data['trials'] = data['trials'].drop( 1a

784 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index) 

785 

786 # Make a list of panels, if inputs are missing, instead input a text to display 

787 panels = [] 1a

788 # Panel A, B, C: Trace on frame 

789 for cam in cams: 1a

790 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None: 1a

791 panels.append((plot_trace_on_frame, 1a

792 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam})) 

793 else: 

794 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame')) 

795 

796 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels 

797 if data['trials'] is None: 1a

798 panels.extend([(None, 'No trial data,\ncannot compute trial avgs') for i in range(7)]) 

799 else: 

800 # Panel D: Motion energy 

801 camera_dict = {'left': {'motion_energy': data['left_ROIMotionEnergy'], 'times': data['left_times']}, 1a

802 'right': {'motion_energy': data['right_ROIMotionEnergy'], 'times': data['right_times']}, 

803 'body': {'motion_energy': data['body_ROIMotionEnergy'], 'times': data['body_times']}} 

804 for cam in ['left', 'right', 'body']: # Remove cameras where we don't have motion energy AND camera times 1a

805 if camera_dict[cam]['motion_energy'] is None or camera_dict[cam]['times'] is None: 1a

806 _ = camera_dict.pop(cam) 

807 if len(camera_dict) > 0: 1a

808 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']})) 1a

809 else: 

810 panels.append((None, 'Data missing\nMotion energy')) 

811 

812 # Panel E: Wheel position 

813 if data['wheel']: 1a

814 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position, 1a

815 'wheel_time': data['wheel'].timestamps, 

816 'trials_df': data['trials']})) 

817 else: 

818 panels.append((None, 'Data missing\nWheel position')) 

819 

820 # Panel F, G: Paw speed and nose speed 

821 # Try if all data is there for left cam first, otherwise right 

822 for cam in ['left', 'right']: 1a

823 fail = False 1a

824 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None 1a

825 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])): 

826 break 1a

827 fail = True 

828 if not fail: 1a

829 paw = 'r' if cam == 'left' else 'l' 1a

830 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 1a

831 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam})) 

832 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 1a

833 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False, 

834 'cam': cam})) 

835 else: 

836 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms') for i in range(2)]) 

837 

838 # Panel H and I: Lick plots 

839 if data['licks'] and data['licks'].times.shape[0] > 0: 1a

840 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 1a

841 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 1a

842 else: 

843 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)]) 

844 

845 # Panel J: pupil plot 

846 # Try if all data is there for left cam first, otherwise right 

847 for cam in ['left', 'right']: 1a

848 fail = False 1a

849 if (data[f'{cam}_times'] is not None and data[f'{cam}_features'] is not None 1a

850 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features']) 

851 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))): 

852 break 1a

853 fail = True 

854 if not fail: 1a

855 panels.append((plot_pupil_diameter_hist, 1a

856 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth, 

857 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam})) 

858 else: 

859 panels.append((None, 'Data missing or corrupt\nPupil diameter')) 

860 

861 # Plotting 

862 plt.rcParams.update({'font.size': 10}) 1a

863 fig = plt.figure(figsize=(17, 10)) 1a

864 for i, panel in enumerate(panels): 1a

865 ax = plt.subplot(2, 5, i + 1) 1a

866 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold') 1a

867 # Check if there was in issue with inputs, if yes, print the respective text 

868 if panel[0] is None: 1a

869 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center', 

870 verticalalignment='center', transform=ax.transAxes) 

871 plt.axis('off') 

872 else: 

873 try: 1a

874 panel[0](**panel[1]) 1a

875 except BaseException: 1a

876 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc()) 1a

877 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold', 1a

878 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes) 

879 plt.axis('off') 1a

880 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1a

881 

882 return fig 1a