Coverage for ibllib/plots/figures.py: 68%
511 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""
2Module that produces figures, usually for the extraction pipeline
3"""
4import logging
5import time
6from pathlib import Path
7import traceback
8from string import ascii_uppercase
10import numpy as np
11import pandas as pd
12import scipy.signal
13import matplotlib.pyplot as plt
15from neurodsp import voltage
16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot
17from one.api import ONE
18import one.alf.io as alfio
19from one.alf.exceptions import ALFObjectNotFound
20from ibllib.io.video import get_video_frame, url_from_eid
21import spikeglx
22import neuropixel
23from brainbox.plot import driftmap
24from brainbox.io.spikeglx import Streamer
25from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
26 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
27from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions
28from brainbox.io.one import load_spike_sorting_fast
29from brainbox.behavior import training
30from iblutil.numerical import ismember
31from ibllib.plots.misc import Density
34logger = logging.getLogger(__name__)
def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Normalise the font sizes used on a matplotlib axis.

    :param ax: matplotlib Axes whose axis labels, tick labels and title are resized
    :param labels: font size applied to the x and y axis labels (and to the colorbar label when cmap is True)
    :param ticklabels: font size applied to the tick labels (and to the colorbar tick labels when cmap is True)
    :param title: font size applied to the axis title
    :param cmap: if True, also resize the colorbar attached to the last image drawn on ``ax``
    :return: None
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if not cmap:
        return
    # the colorbar belongs to the most recently added image on this axis
    colorbar_ax = ax.images[-1].colorbar.ax
    colorbar_ax.tick_params(labelsize=ticklabels)
    colorbar_ax.yaxis.get_label().set_fontsize(labels)
def remove_axis_outline(ax):
    """
    Blank out an unused panel by hiding both axes and all four spines of ``ax``.

    :param ax: matplotlib Axes to hide
    :return: None
    """
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
class BehaviourPlots(ReportSnapshot):
    """
    Behavioural snapshot plots of a session: psychometric curve, chronometric curve and
    reaction time across trials, saved as PNG files in ``<session_path>/snapshot/behaviour``.
    """

    signature = {
        'input_files': [
            ('*trials.table.pqt', 'alf', True),
        ],
        'output_files': [
            ('psychometric_curve.png', 'snapshot/behaviour', True),
            ('chronometric_curve.png', 'snapshot/behaviour', True),
            ('reaction_time_with_trials.png', 'snapshot/behaviour', True)
        ]
    }

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        :param eid: session UUID
        :param session_path: local session path; when falsy it is resolved with ``one.eid2path(eid)``
        :param one: ONE instance
        :param kwargs: forwarded to ReportSnapshot
        """
        self.one = one
        self.eid = eid
        self.session_path = session_path or self.one.eid2path(self.eid)
        super().__init__(self.session_path, self.eid, one=self.one, **kwargs)
        self.output_directory = self.session_path.joinpath('snapshot', 'behaviour')
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _run(self):
        """
        Load the trials object from the session ``alf`` folder and save one figure per behavioural plot.

        :return: list of paths of the saved PNG files
        """
        trials = alfio.load_object(self.session_path.joinpath('alf'), 'trials')
        # figure title built from the last three components of the session path
        title = '_'.join(self.session_path.parts[-3:])

        plotters = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )
        saved = []
        for plot_fcn, filename in plotters:
            fig, ax = plot_fcn(trials, title=title, figsize=(8, 6))
            set_axis_label_size(ax)
            out_png = Path(self.output_directory).joinpath(filename)
            saved.append(out_png)
            fig.savefig(out_png)
            plt.close(fig)
        return saved
128# TODO put into histology and alignment pipeline
class HistologySlices(ReportSnapshotProbe):
    """
    Plots coronal and sagittal slice showing electrode locations, next to the brain regions
    traversed by the probe.
    """

    def _run(self):
        """
        :return: list containing the saved ``histology_slices.png`` path, or an empty list when the
            histology status lookup is not strictly positive
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        if not self.hist_lookup[self.histology_status] > 0:
            return []

        xyz = electrodes['mlapdv']
        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        # top panel: tilted slice along axis 1, electrodes drawn from mlapdv columns 0 and 2 (scaled by 1e6)
        ax_top = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 1, ax=ax_top)
        ax_top.scatter(xyz[:, 0] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')
        ax_top.set_title(f"{self.pid_label}")

        # bottom panel: tilted slice along axis 0, electrodes drawn from mlapdv columns 1 and 2 (scaled by 1e6)
        ax_bottom = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(xyz, 0, ax=ax_bottom)
        ax_bottom.scatter(xyz[:, 1] * 1e6, xyz[:, 2] * 1e6, s=8, c='r')

        # narrow right column spanning both rows: brain regions strip
        ax_regions = fig.add_subplot(grid[:, 1])
        plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)

        out_png = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(out_png)
        plt.close(fig)
        return [out_png]

    def get_probe_signature(self):
        """Declare the optional electrodeSites inputs and the histology_slices.png output."""
        alf_folder = f'alf/{self.pname}'
        input_signature = [(f'electrodeSites.{attribute}.npy', alf_folder, False)
                           for attribute in ('localCoordinates', 'brainLocationIds_ccf_2017', 'mlapdv')]
        output_signature = [('histology_slices.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS images of a probe, each next to a brain regions strip when the
    histology status is set (the strip is blanked otherwise).
    """

    def _run(self):
        """
        :return: list of saved paths, ``[lfp_spectrum.png, lfp_rms.png]``
        """
        assert self.pid

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        qc_folder = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')

        def save_with_regions(fig, axs, filename):
            # common tail of both figures: resize labels, draw or blank the region strip, save and close
            set_axis_label_size(axs[0], cmap=True)
            if self.histology_status:
                plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                                   title=self.histology_status)
                set_axis_label_size(axs[1])
            else:
                remove_axis_outline(axs[1])
            out_png = Path(self.output_directory).joinpath(filename)
            fig.savefig(out_png)
            plt.close(fig)
            return out_png

        output_files = []

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        spectrum = alfio.load_object(qc_folder, 'ephysSpectralDensityLF', namespace='iblqc')
        _, _, _ = image_lfp_spectrum_plot(spectrum.power, spectrum.freqs, clim=[-65, -95],
                                          fig_kwargs={'figsize': (8, 6)}, ax=axs[0],
                                          display=True, title=f"{self.pid_label}")
        output_files.append(save_with_regions(fig, axs, "lfp_spectrum.png"))

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        rms = alfio.load_object(qc_folder, 'ephysTimeRmsLF', namespace='iblqc')
        _, _, _ = image_rms_plot(rms.rms, rms.timestamps, median_subtract=False, band='LFP', clim=[-35, -45],
                                 ax=axs[0], cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True,
                                 title=f"{self.pid_label}")
        output_files.append(save_with_regions(fig, axs, "lfp_rms.png"))

        return output_files

    def get_probe_signature(self):
        """Declare the required LF QC inputs, the optional electrodeSites inputs and the two PNG outputs."""
        raw_folder = f'raw_ephys_data/{self.pname}'
        alf_folder = f'alf/{self.pname}'
        snapshot_folder = f'snapshot/{self.pname}'
        input_signature = [
            ('_iblqc_ephysTimeRmsLF.rms.npy', raw_folder, True),
            ('_iblqc_ephysTimeRmsLF.timestamps.npy', raw_folder, True),
            ('_iblqc_ephysSpectralDensityLF.freqs.npy', raw_folder, True),
            ('_iblqc_ephysSpectralDensityLF.power.npy', raw_folder, True),
            ('electrodeSites.localCoordinates.npy', alf_folder, False),
            ('electrodeSites.brainLocationIds_ccf_2017.npy', alf_folder, False),
            ('electrodeSites.mlapdv.npy', alf_folder, False),
        ]
        output_signature = [('lfp_spectrum.png', snapshot_folder, True),
                            ('lfp_rms.png', snapshot_folder, True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS image of a probe, next to a brain regions strip when the histology status is set
    (the strip is blanked otherwise).
    """

    def _run(self):
        """
        :return: list containing the saved ``ap_rms.png`` path
        """
        assert self.pid

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        image_ax, regions_ax = axs[0], axs[1]
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP',
                               namespace='iblqc')
        _, _, _ = image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=image_ax,
                                 fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(image_ax, cmap=True)

        if not self.histology_status:
            remove_axis_outline(regions_ax)
        else:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True,
                               ax=regions_ax, title=self.histology_status)
            set_axis_label_size(regions_ax)

        out_png = Path(self.output_directory).joinpath("ap_rms.png")
        fig.savefig(out_png)
        plt.close(fig)
        return [out_png]

    def get_probe_signature(self):
        """Declare the required AP QC inputs, the optional electrodeSites inputs and the ap_rms.png output."""
        raw_folder = f'raw_ephys_data/{self.pname}'
        alf_folder = f'alf/{self.pname}'
        input_signature = [
            ('_iblqc_ephysTimeRmsAP.rms.npy', raw_folder, True),
            ('_iblqc_ephysTimeRmsAP.timestamps.npy', raw_folder, True),
            ('electrodeSites.localCoordinates.npy', alf_folder, False),
            ('electrodeSites.brainLocationIds_ccf_2017.npy', alf_folder, False),
            ('electrodeSites.mlapdv.npy', alf_folder, False),
        ]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class SpikeSorting(ReportSnapshotProbe):
    """
    Plots the spike sorting driftmap (spike depths over time) of a probe, one figure per spike sorting run
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Plot and save the driftmap of one spike sorting collection, next to a brain regions strip when
        the histology status is set (the strip is blanked otherwise).

        :param spikes: spikes object providing 'times', 'depths' and 'clusters'
        :param clusters: clusters object providing 'depths'
        :param channels: channels dict with 'axial_um' (and 'atlas_id' when the histology status is set)
        :param collection: ALF collection of the run relative to the session path, under ``alf/<pname>``
        :return: (path of the saved PNG, figure, axes); the figure is closed before returning
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # the run stored directly in alf/<pname> is labelled 'ks2matlab'; sub-collections keep their own name
        run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        if self.histology_status:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)

        return outfile, fig, axs

    def _run(self, collection=None):
        """
        Runs for the initiated PID and plots the spike sorting driftmap(s).

        On the server, the single ``collection`` is loaded from the local ALF files. Elsewhere, every collection
        matching ``alf/<pname>*`` that contains ``spikes.times.npy`` (as listed by ONE) is loaded and plotted,
        unless the expected snapshots are all already present.

        :param collection: ALF collection relative to the session path; required on the server
        :return: list of saved PNG paths
        """
        if self.location == 'server':
            assert collection
            alf_path = self.session_path.joinpath(collection)
            spikes = alfio.load_object(alf_path, 'spikes')
            clusters = alfio.load_object(alf_path, 'clusters')
            channels = alfio.load_object(alf_path, 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]
            out, _, _ = self._plot_driftmap(spikes, clusters, channels, collection)
            return [out]

        self.histology_status = self.get_histology_status()
        all_here, existing_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
        if all_here and len(existing_files) == len(spike_sorting_runs):
            return existing_files
        logger.info(self.output_directory)
        # Every run is (re)plotted below, so only collect the files produced now: appending to the list of
        # snapshots found by assert_expected would report the already existing ones twice
        output_files = []
        for run in spike_sorting_runs:
            collection = str(Path(run).parent.as_posix())
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', collection)

            out, _, _ = self._plot_driftmap(spikes, clusters, channels, collection)
            output_files.append(out)
        return output_files

    def get_probe_signature(self):
        """Declare the spikes/clusters/channels inputs of every alf/<pname>* collection and the raster outputs."""
        collection = f'alf/{self.pname}*'
        input_signature = [('spikes.times.npy', collection, True),
                           ('spikes.amps.npy', collection, True),
                           ('spikes.depths.npy', collection, True),
                           ('clusters.depths.npy', collection, True),
                           ('channels.localCoordinates.npy', collection, False),
                           ('channels.mlapdv.npy', collection, False),
                           ('channels.brainLocationIds_ccf_2017.npy', collection, False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """
        Expand the wildcard input signature into one entry per local folder containing ``spikes.times.npy``
        under ``alf``; falls back to the wildcard signature when no such folder exists locally.
        """
        folder_probes = [f.parent for f in Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')]
        full_input_files = [(sig[0], str(folder.relative_to(self.session_path)), sig[2])
                            for sig in self.signature['input_files'] for folder in folder_probes]
        self.input_files = full_input_files if len(full_input_files) != 0 else self.signature['input_files']
        self.output_files = self.signature['output_files']
class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots raw electrophysiology AP band with the detected bad channels highlighted
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        """Declare the raw AP metadata inputs and the four bad channels snapshots written by ``_run``."""
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        # features overview plus one figure per display written by ephys_bad_channels
        # (highpass, destripe, difference); each file is listed once
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """
        Runs for the initiated PID: loads one second of raw AP data (30 min into the recording), detects the bad
        channels and saves the bad channels snapshots. If the 4 snapshots already exist they are returned as is.

        Outside the server the data is streamed and the channels' brain regions are attached when available;
        on the server the local ``*ap.*bin`` file is read.

        :return: list of snapshot paths (empty on the server when no local AP binary file is found)
        """
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            nsecs = 1
            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If T0 is greater than recording length, take 500 sec before end. Clamp at 0 so that recordings
            # shorter than 500 sec start at the beginning: a negative T0 would select an empty sample range
            if sr.rl < T0:
                T0 = max(int(sr.rl - 500), 0)
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            # single shank extracted from a 4-shank NP2.4 probe: keep the header of that shank only
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files
def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Display a raw voltage snippet with the detected bad channels overlaid, and a summary figure of the channel
    features used for the detection.

    :param raw: array (nc, ns), channels along the first axis; presumably in Volts (the feature plots scale
        by 1e6 to label values in uV) - TODO confirm
    :param fs: sampling frequency (Hz); fs >= 2600 selects the AP band display settings, otherwise LFP settings
    :param channel_labels: array (nc,) of channel labels; 1 dead, 2 noisy and 3 outside channels are highlighted
    :param channel_features: dict of per-channel features with keys 'rms_raw', 'psd_hf', 'xcor_hf', 'xcor_lf'
    :param h: trace header passed to ``voltage.destripe``
    :param channels: optional dict with 'atlas_id' (and optionally 'title') used to draw the brain regions strip
    :param title: title of the features figure and prefix of the saved file names
    :param save_dir: if given, all figures are saved to this directory
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: optional list to which the created displays are appended
    :param br: brain regions object passed to ``plot_brain_regions`` / ``viewephys``
    :param pid_info: label prepended to the titles of the voltage displays (matplotlib backend)
    :param plot_backend: 'matplotlib', or anything else to use the ``viewspikes`` viewer
    :return: (fig, eqcs, output_files) when ``save_dir`` is given, (fig, eqcs) otherwise
    """
    nc, ns = raw.shape
    rl = ns / fs

    def gain2level(gain):
        # display gain (dB) -> symmetric colour limits
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        eqc_xrange = [450, 500]
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
        eqc_levels = gain2level(eqc_gain)
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        eqc_xrange = [450, 950]
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
        eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            # horizontal marker lines across the whole snippet for each category of bad channel
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa
        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    # clip the noisy-channel markers just below the top of the y axis so that out-of-range values stay visible;
    # the clip follows ylim_psd_hf so it is correct for both the AP and the LFP band
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], ylim_psd_hf[1] * 0.999), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    # legend must be called: assigning to the attribute would only shadow the Axes method
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is not None:
        output_files = [Path(save_dir).joinpath(f"{title}.png")]
        fig.savefig(output_files[0])
        for eqc in eqcs:
            if plot_backend == 'matplotlib':
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
                eqc.figure.savefig(str(output_files[-1]))
            else:
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
                eqc.grab().save(str(output_files[-1]))
        return fig, eqcs, output_files
    else:
        return fig, eqcs
def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a raw ephys snippet and display it as a density image in one subplot of a figure,
    optionally marking the detected bad channels.

    :param raw: raw ephys data, Nc x Ns (channels x samples); displayed with time (ms) on the x-axis
        and channel on the y-axis
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start
    :param i_plt: increment of plot to display image one (start from 0, has to be < n_plt)
    :param n_plt: total number of subplot on figure
    :param fig: figure handle
    :param axs: axis handle
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; other channel counts are not destriped
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if i_plt does not index one of the subplots
    '''

    # Import
    from neurodsp import voltage
    from ibllib.plots import Density

    # Init fig
    if fig is None or axs is None:
        # width_ratios needs one entry per column (an int is rejected by matplotlib), and squeeze=False keeps
        # axs an indexable 1D array of subplots even when n_plt == 1
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), squeeze=False,
                                gridspec_kw={'width_ratios': [4] * n_plt})
        axs = axs[0]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt+1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    [nc, ns] = raw.shape
    if nc == N_CHAN:
        destripe = voltage.destripe(raw, fs=fs)
        X = destripe[:, :int(DISPLAY_TIME * fs)].T
        Xs = X[SAMPLE_SKIP:].T  # Remove artifact at beginning
        Tplot = Xs.shape[1] / fs

        # PLOT RAW DATA
        d = Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X)  # noqa
        axs[i_plt].set_ylabel('')
        axs[i_plt].set_xlim((0, Tplot * 1e3))
        axs[i_plt].set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect and remove bad channels prior to spike detection
            labels, xfeats = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data
            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        axs[i_plt].title.set_text(title_plt)

    else:
        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        axs[i_plt].set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs
672def dlc_qc_plot(session_path, one=None):
673 """
674 Creates DLC QC plot.
675 Data is searched first locally, then on Alyx. Panels that lack required data are skipped.
677 Required data to create all panels
678 'raw_video_data/_iblrig_bodyCamera.raw.mp4',
679 'raw_video_data/_iblrig_leftCamera.raw.mp4',
680 'raw_video_data/_iblrig_rightCamera.raw.mp4',
681 'alf/_ibl_bodyCamera.dlc.pqt',
682 'alf/_ibl_leftCamera.dlc.pqt',
683 'alf/_ibl_rightCamera.dlc.pqt',
684 'alf/_ibl_bodyCamera.times.npy',
685 'alf/_ibl_leftCamera.times.npy',
686 'alf/_ibl_rightCamera.times.npy',
687 'alf/_ibl_leftCamera.features.pqt',
688 'alf/_ibl_rightCamera.features.pqt',
689 'alf/rightROIMotionEnergy.position.npy',
690 'alf/leftROIMotionEnergy.position.npy',
691 'alf/bodyROIMotionEnergy.position.npy',
692 'alf/_ibl_trials.choice.npy',
693 'alf/_ibl_trials.feedbackType.npy',
694 'alf/_ibl_trials.feedback_times.npy',
695 'alf/_ibl_trials.stimOn_times.npy',
696 'alf/_ibl_wheel.position.npy',
697 'alf/_ibl_wheel.timestamps.npy',
698 'alf/licks.times.npy',
700 :params session_path: Path to session data on disk
701 :params one: ONE instance, if None is given, default ONE is instantiated
702 :returns: Matplotlib figure
703 """
705 one = one or ONE() 1ca
706 # hack for running on cortexlab local server
707 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1ca
708 one = ONE(base_url='https://alyx.internationalbrainlab.org')
709 data = {} 1ca
710 cams = ['left', 'right', 'body'] 1ca
711 session_path = Path(session_path) 1ca
713 # Load data for each camera
714 for cam in cams: 1ca
715 # Load a single frame for each video
716 # Check if video data is available locally,if yes, load a single frame
717 video_path = session_path.joinpath('raw_video_data', f'_iblrig_{cam}Camera.raw.mp4') 1ca
718 if video_path.exists(): 1ca
719 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0] 1a
720 # If not, try to stream a frame (try three times)
721 else:
722 try: 1c
723 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1c
724 for tries in range(3):
725 try:
726 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
727 break
728 except BaseException:
729 if tries < 2:
730 tries += 1
731 logger.info(f"Streaming {cam} video failed, retrying x{tries}")
732 time.sleep(30)
733 else:
734 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.")
735 data[f'{cam}_frame'] = None
736 except KeyError: 1c
737 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1c
738 data[f'{cam}_frame'] = None 1c
739 # Other camera associated data
740 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1ca
741 # Check locally first, then try to load from alyx, if nothing works, set to None
742 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1ca
743 continue 1ca
744 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1ca
745 if len(local_file) > 0: 1ca
746 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0]) 1a
747 else:
748 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1c
749 if len(alyx_ds) > 0: 1c
750 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0])
751 else:
752 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1c
753 data[f'{cam}_{feat}'] = None 1c
754 # Sometimes there is a file but the object is empty, set to None
755 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1ca
756 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.")
757 data[f'{cam}_{feat}'] = None
759 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong
760 assert any([data[f'{cam}_frame'] is not None for cam in cams]), "No camera data could be loaded, aborting." 1ca
761 assert any([data[f'{cam}_dlc'] is not None for cam in cams]), "No DLC data could be loaded, aborting." 1a
762 assert any([data[f'{cam}_times'] is not None for cam in cams]), "No camera times data could be loaded, aborting." 1a
764 # Load session level data
765 for alf_object in ['trials', 'wheel', 'licks']: 1a
766 try: 1a
767 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath('alf'), alf_object) # load locally 1a
768 continue 1a
769 except ALFObjectNotFound:
770 pass
771 try:
772 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object) # then try from alyx
773 except ALFObjectNotFound:
774 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.")
775 data[f'{alf_object}'] = None
777 # Simplify and clean up trials data
778 if data['trials']: 1a
779 data['trials'] = pd.DataFrame( 1a
780 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']})
781 # Discard nan events and too long trials
782 data['trials'] = data['trials'].dropna() 1a
783 data['trials'] = data['trials'].drop( 1a
784 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index)
786 # Make a list of panels, if inputs are missing, instead input a text to display
787 panels = [] 1a
788 # Panel A, B, C: Trace on frame
789 for cam in cams: 1a
790 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None: 1a
791 panels.append((plot_trace_on_frame, 1a
792 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam}))
793 else:
794 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame'))
796 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels
797 if data['trials'] is None: 1a
798 panels.extend([(None, 'No trial data,\ncannot compute trial avgs') for i in range(7)])
799 else:
800 # Panel D: Motion energy
801 camera_dict = {'left': {'motion_energy': data['left_ROIMotionEnergy'], 'times': data['left_times']}, 1a
802 'right': {'motion_energy': data['right_ROIMotionEnergy'], 'times': data['right_times']},
803 'body': {'motion_energy': data['body_ROIMotionEnergy'], 'times': data['body_times']}}
804 for cam in ['left', 'right', 'body']: # Remove cameras where we don't have motion energy AND camera times 1a
805 if camera_dict[cam]['motion_energy'] is None or camera_dict[cam]['times'] is None: 1a
806 _ = camera_dict.pop(cam)
807 if len(camera_dict) > 0: 1a
808 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']})) 1a
809 else:
810 panels.append((None, 'Data missing\nMotion energy'))
812 # Panel E: Wheel position
813 if data['wheel']: 1a
814 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position, 1a
815 'wheel_time': data['wheel'].timestamps,
816 'trials_df': data['trials']}))
817 else:
818 panels.append((None, 'Data missing\nWheel position'))
820 # Panel F, G: Paw speed and nose speed
821 # Try if all data is there for left cam first, otherwise right
822 for cam in ['left', 'right']: 1a
823 fail = False 1a
824 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None 1a
825 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])):
826 break 1a
827 fail = True
828 if not fail: 1a
829 paw = 'r' if cam == 'left' else 'l' 1a
830 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 1a
831 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam}))
832 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'], 1a
833 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False,
834 'cam': cam}))
835 else:
836 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms') for i in range(2)])
838 # Panel H and I: Lick plots
839 if data['licks'] and data['licks'].times.shape[0] > 0: 1a
840 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 1a
841 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']})) 1a
842 else:
843 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)])
845 # Panel J: pupil plot
846 # Try if all data is there for left cam first, otherwise right
847 for cam in ['left', 'right']: 1a
848 fail = False 1a
849 if (data[f'{cam}_times'] is not None and data[f'{cam}_features'] is not None 1a
850 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features'])
851 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))):
852 break 1a
853 fail = True
854 if not fail: 1a
855 panels.append((plot_pupil_diameter_hist, 1a
856 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth,
857 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam}))
858 else:
859 panels.append((None, 'Data missing or corrupt\nPupil diameter'))
861 # Plotting
862 plt.rcParams.update({'font.size': 10}) 1a
863 fig = plt.figure(figsize=(17, 10)) 1a
864 for i, panel in enumerate(panels): 1a
865 ax = plt.subplot(2, 5, i + 1) 1a
866 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold') 1a
867 # Check if there was in issue with inputs, if yes, print the respective text
868 if panel[0] is None: 1a
869 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center',
870 verticalalignment='center', transform=ax.transAxes)
871 plt.axis('off')
872 else:
873 try: 1a
874 panel[0](**panel[1]) 1a
875 except BaseException: 1a
876 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc()) 1a
877 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold', 1a
878 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
879 plt.axis('off') 1a
880 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1a
882 return fig 1a