Coverage for ibllib/plots/figures.py: 15%
517 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""
2Module that produces figures, usually for the extraction pipeline
3"""
4import logging
5import time
6from pathlib import Path
7import traceback
8from string import ascii_uppercase
10import numpy as np
11import pandas as pd
12import scipy.signal
13import matplotlib.pyplot as plt
15from ibldsp import voltage
16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot
17from one.api import ONE
18import one.alf.io as alfio
19from one.alf.exceptions import ALFObjectNotFound
20from ibllib.io.video import get_video_frame, url_from_eid
21import spikeglx
22import neuropixel
23from brainbox.plot import driftmap
24from brainbox.io.spikeglx import Streamer
25from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
26 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
27from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions
28from brainbox.io.one import load_spike_sorting_fast
29from brainbox.behavior import training
30from iblutil.numerical import ismember
31from ibllib.plots.misc import Density
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Function to normalise the font size of the axis labels, tick labels and title of an axis.

    :param ax: matplotlib axis, modified in place
    :param labels: font size of the x and y axis labels (and of the colorbar label when `cmap` is True)
    :param ticklabels: font size of the tick labels (and of the colorbar tick labels when `cmap` is True)
    :param title: font size of the axis title
    :param cmap: if True, also resize the colorbar attached to the last image drawn on the axis; this is
        skipped when the axis holds no image or when that image has no colorbar
    :return: None
    """

    ax.xaxis.get_label().set_fontsize(labels)
    ax.yaxis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if cmap:
        # The colorbar is looked up through the most recent image of the axis: guard against an axis
        # without any image (IndexError) or an image without a colorbar (AttributeError on None)
        cbar = ax.images[-1].colorbar if ax.images else None
        if cbar is not None:
            cbar.ax.tick_params(labelsize=ticklabels)
            cbar.ax.yaxis.get_label().set_fontsize(labels)
def remove_axis_outline(ax):
    """
    Hide the x and y axes and all four spines of an axis, leaving an empty panel.

    :param ax: matplotlib axis, modified in place
    :return: None
    """
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
class BehaviourPlots(ReportSnapshot):
    """Behavioural plots."""

    @property
    def signature(self):
        signature = {
            'input_files': [
                ('*trials.table.pqt', self.trials_collection, True),
            ],
            'output_files': [
                ('psychometric_curve.png', 'snapshot/behaviour', True),
                ('chronometric_curve.png', 'snapshot/behaviour', True),
                ('reaction_time_with_trials.png', 'snapshot/behaviour', True)
            ]
        }
        return signature

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        Generate and upload behaviour plots.

        Parameters
        ----------
        eid : str, uuid.UUID
            An experiment UUID.
        session_path : pathlib.Path, str
            A session path. If not given, it is resolved from the eid with ONE.
        one : one.api.One
            An instance of ONE for registration to Alyx.
        trials_collection : str
            The location of the trials data (default: 'alf'). The keyword `task_collection` is also
            accepted; `trials_collection` takes precedence when both are given.
        kwargs
            Arguments for ReportSnapshot constructor.
        """
        self.one = one
        self.eid = eid
        self.session_path = Path(session_path) if session_path else self.one.eid2path(self.eid)
        # Both spellings are popped so that neither is forwarded (and silently ignored) by the
        # ReportSnapshot constructor. This must happen before super().__init__ because the
        # `signature` property reads `self.trials_collection`.
        trials_collection = kwargs.pop('trials_collection', None)
        task_collection = kwargs.pop('task_collection', None)
        self.trials_collection = trials_collection or task_collection or 'alf'
        super(BehaviourPlots, self).__init__(self.session_path, self.eid, one=self.one,
                                             **kwargs)
        # Output directory should mirror trials collection, sans 'alf' part
        self.output_directory = self.session_path.joinpath(
            'snapshot', 'behaviour', self.trials_collection.removeprefix('alf').strip('/'))
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _run(self):
        """
        Load the trials object and save the psychometric curve, chronometric curve and reaction time
        over trials figures to the output directory.

        Returns
        -------
        list of pathlib.Path
            The saved figure files, in the order psychometric, chronometric, reaction time over trials.
        """
        trials = alfio.load_object(self.session_path.joinpath(self.trials_collection), 'trials')
        if self.one:
            title = self.one.path2ref(self.session_path, as_dict=False)
        else:
            title = '_'.join(self.session_path.parts[-3:])

        # (plotting function, output file name) of each figure, in the order they are produced
        plots = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )
        output_files = []
        for plot_function, filename in plots:
            fig, ax = plot_function(trials, title=title, figsize=(8, 6))
            try:
                set_axis_label_size(ax)
                save_path = Path(self.output_directory).joinpath(filename)
                output_files.append(save_path)
                fig.savefig(save_path)
            finally:
                # always release the figure, even if saving fails
                plt.close(fig)

        return output_files
# TODO put into histology and alignment pipeline
class HistologySlices(ReportSnapshotProbe):
    """Plots coronal and sagittal slice showing electrode locations."""

    def _run(self):
        """
        Plot two tilted slices through the electrode sites, plus the brain regions along the probe.

        Nothing is plotted, and an empty list is returned, unless the histology status maps to a
        positive value in `self.hist_lookup`.

        :return: list holding the path of the saved figure, or an empty list
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        if not self.hist_lookup[self.histology_status] > 0:
            return []

        # the electrode coordinates are scaled by 1e6 for the scatter overlays only
        xyz_scaled = electrodes['mlapdv'] * 1e6

        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        ax_top = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(electrodes['mlapdv'], 1, ax=ax_top)
        ax_top.scatter(xyz_scaled[:, 0], xyz_scaled[:, 2], s=8, c='r')
        ax_top.set_title(f"{self.pid_label}")

        ax_bottom = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(electrodes['mlapdv'], 0, ax=ax_bottom)
        ax_bottom.scatter(xyz_scaled[:, 1], xyz_scaled[:, 2], s=8, c='r')

        ax_regions = fig.add_subplot(grid[:, 1])
        plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)

        save_path = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(save_path)
        plt.close(fig)
        return [save_path]

    def get_probe_signature(self):
        """Set the input (electrode sites, optional) and output (slice snapshot) signature of the probe."""
        alf_collection = f'alf/{self.pname}'
        input_names = ('electrodeSites.localCoordinates.npy',
                       'electrodeSites.brainLocationIds_ccf_2017.npy',
                       'electrodeSites.mlapdv.npy')
        self.signature = {
            'input_files': [(name, alf_collection, False) for name in input_names],
            'output_files': [('histology_slices.png', f'snapshot/{self.pname}', True)],
        }
class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS plots
    """

    def _run(self):
        """
        Plot the LFP spectral density and the LFP RMS over time of the probe.

        Each figure shows the data image on the left. When histology is available, the right panel
        shows the brain regions along the probe; otherwise it is blanked.

        :return: list of paths of the saved figures (spectrum first, then RMS)
        """
        assert self.pid

        output_files = []
        # Electrode locations are only fetched off the server. Bind the name explicitly (as
        # BadChannelsAp does) so the brain region panel never references an undefined variable.
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        qc_path = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        spectrum = alfio.load_object(qc_path, 'ephysSpectralDensityLF', namespace='iblqc')
        image_lfp_spectrum_plot(spectrum.power, spectrum.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)},
                                ax=axs[0], display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._plot_brain_regions_panel(axs[1], electrodes)
        output_files.append(self._save_figure(fig, "lfp_spectrum.png"))

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        rms = alfio.load_object(qc_path, 'ephysTimeRmsLF', namespace='iblqc')
        image_rms_plot(rms.rms, rms.timestamps, median_subtract=False, band='LFP', clim=[-35, -45], ax=axs[0],
                       cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._plot_brain_regions_panel(axs[1], electrodes)
        output_files.append(self._save_figure(fig, "lfp_rms.png"))

        return output_files

    def _plot_brain_regions_panel(self, ax, electrodes):
        """Draw the brain regions along the probe on `ax` when histology is available, else blank `ax`."""
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax,
                               title=self.histology_status)
            set_axis_label_size(ax)
        else:
            remove_axis_outline(ax)

    def _save_figure(self, fig, filename):
        """Save `fig` as `filename` in the output directory, close it, and return the saved path."""
        save_path = Path(self.output_directory).joinpath(filename)
        try:
            fig.savefig(save_path)
        finally:
            plt.close(fig)
        return save_path

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
                            ('lfp_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS plots
    """

    def _run(self):
        """
        Plot the AP band RMS over time of the probe.

        When histology is available, the right panel shows the brain regions along the probe;
        otherwise it is blanked.

        :return: list holding the path of the saved figure
        """
        assert self.pid

        output_files = []
        # Electrode locations are only fetched off the server. Bind the name explicitly (as
        # BadChannelsAp does) so the brain region panel never references an undefined variable.
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP', namespace='iblqc')
        image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
                       fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        if self.histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=axs[1],
                               title=self.histology_status)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        save_path = Path(self.output_directory).joinpath("ap_rms.png")
        output_files.append(save_path)
        try:
            fig.savefig(save_path)
        finally:
            # always release the figure, even if saving fails
            plt.close(fig)

        return output_files

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class SpikeSorting(ReportSnapshotProbe):
    """
    Plots the spike sorting driftmap (spike raster over depth) of each spike sorting run of a probe
    insertion. When histology is available, the brain regions along the probe are shown next to it.
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Plot and save the driftmap of a single spike sorting run.

        :param spikes: spikes object with `times`, `depths` and `clusters` attributes
        :param clusters: clusters object with a `depths` attribute
        :param channels: channels dict-like with an 'axial_um' entry (and 'atlas_id' when histology is available)
        :param collection: ALF collection of the run, located under `alf/<pname>`
        :return: path of the saved figure
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # The run label is the collection relative to alf/<pname>; the probe folder itself is labelled 'ks2matlab'
        run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        if self.histology_status:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)
        return outfile

    def _run(self, collection=None):
        """
        Plot the driftmap of the spike sorting run(s) of the probe insertion.

        On the server, a single run is loaded from disk from `collection`, which is required.
        Otherwise, every run listed by ONE (datasets 'spikes.times.npy' in `alf/<pname>*`) is loaded and
        plotted, unless the expected snapshots already exist for all of them.

        :param collection: ALF collection of the spike sorting run (required on the server)
        :return: list of paths of the snapshot figures
        """
        if self.location == 'server':
            assert collection
            alf_path = self.session_path.joinpath(collection)
            spikes = alfio.load_object(alf_path, 'spikes')
            clusters = alfio.load_object(alf_path, 'clusters')
            channels = alfio.load_object(alf_path, 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]
            return [self._plot_driftmap(spikes, clusters, channels, collection)]

        self.histology_status = self.get_histology_status()
        all_here, output_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
        if all_here and len(output_files) == len(spike_sorting_runs):
            return output_files
        logger.info(self.output_directory)
        for run in spike_sorting_runs:
            collection = Path(run).parent.as_posix()
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', collection)

            outfile = self._plot_driftmap(spikes, clusters, channels, collection)
            # `output_files` already holds the snapshots found on disk by `assert_expected`:
            # do not list a regenerated snapshot twice
            if outfile not in output_files:
                output_files.append(outfile)

        return output_files

    def get_probe_signature(self):
        input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
                           ('spikes.amps.npy', f'alf/{self.pname}*', True),
                           ('spikes.depths.npy', f'alf/{self.pname}*', True),
                           ('clusters.depths.npy', f'alf/{self.pname}*', True),
                           ('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
                           ('channels.mlapdv.npy', f'alf/{self.pname}*', False),
                           ('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """
        Expand the input signature to every folder under `alf` that contains a 'spikes.times.npy' file.

        Falls back to the wildcard probe signature when no such file exists on disk.
        """
        spikes_files = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')
        folder_probes = [f.parent for f in spikes_files]

        full_input_files = [(sig[0], str(folder.relative_to(self.session_path)), sig[2])
                            for sig in self.signature['input_files']
                            for folder in folder_probes]
        self.input_files = full_input_files if full_input_files else self.signature['input_files']
        self.output_files = self.signature['output_files']
class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots raw electrophysiology AP band
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        # One entry per figure saved by `ephys_bad_channels` with destripe=True: the channel features
        # figure, then the highpass, destripe and difference voltage displays
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """
        Load one second of AP band data, detect bad channels and save the bad channel snapshots.

        Off the server, the snippet starting at 30 min is streamed for the insertion. On the server, it is read
        from the local AP binary file, starting at 30 min, or 500 s before the end when the recording is
        shorter than 30 min (never before the first sample).

        :return: list of paths of the saved figures; the existing snapshots if all 4 are already present;
            an empty list on the server when no AP binary file is found
        """
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30  # start of the displayed snippet (s)
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            nsecs = 1
            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If T0 is greater than recording length, take 500 sec before end. Clamp at 0 so that
            # recordings shorter than 500 s start at the first sample rather than at a negative index.
            if sr.rl < T0:
                T0 = max(int(sr.rl - 500), 0)
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files
def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Display the bad channel detection features and the voltage traces of a raw ephys snippet.

    :param raw: raw voltage snippet, shape (nc, ns) (channels x samples); presumably in Volts, as it is
        scaled by 1e6 for the PSD in uV ** 2 / Hz
    :param fs: sampling frequency (Hz); fs >= 2600 is treated as AP band, otherwise as LFP band
    :param channel_labels: per channel labels: 1 dead, 2 noisy, 3 outside the brain
    :param channel_features: per channel features with 'rms_raw', 'psd_hf', 'xcor_hf' and 'xcor_lf' entries
    :param h: trace header passed to `voltage.destripe`
    :param channels: optional channels dict with 'atlas_id' (and optionally 'title') for the brain region panel
    :param title: suptitle of the features figure and file name prefix of the saved figures
    :param save_dir: directory where figures are saved; nothing is saved when None
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: list of existing displays to append the new displays to
    :param br: brain regions passed to `plot_brain_regions` (matplotlib) or viewephys
    :param pid_info: prefix of the voltage display titles (matplotlib backend)
    :param plot_backend: 'matplotlib', or any other value to use the viewspikes GUI
    :return: (fig, eqcs, output_files) when `save_dir` is given, else (fig, eqcs)
    """
    nc, ns = raw.shape
    rl = ns / fs

    def gain2level(gain):
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        eqc_xrange = [450, 500]
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
        eqc_levels = gain2level(eqc_gain)
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        eqc_xrange = [450, 950]
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
        eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            # mark outside (goldenrod), noisy (red) and dead (blue) channels along the whole time axis (ms)
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            # right hand panel of the figure: brain regions, or blank when no channel information
            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa
        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    # clip the noisy markers just below the top of the y axis so they stay visible for both AP and LFP ranges
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], ylim_psd_hf[1] * 0.999), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    # Axes.legend is a method: call it rather than overwriting it with a list
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is not None:
        output_files = [Path(save_dir).joinpath(f"{title}.png")]
        fig.savefig(output_files[0])
        for eqc in eqcs:
            if plot_backend == 'matplotlib':
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
                eqc.figure.savefig(str(output_files[-1]))
            else:
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
                eqc.grab().save(str(output_files[-1]))
        return fig, eqcs, output_files
    else:
        return fig, eqcs
def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a raw ephys snippet and display it in one panel of a (possibly multi-panel) figure.

    :param raw: raw ephys data, Nc x Ns (channels x samples)
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start (used in the panel title)
    :param i_plt: index of the subplot to display the image in (starts from 0, has to be < n_plt)
    :param n_plt: total number of subplots on figure
    :param fig: figure handle; a new figure is created when fig or axs is None
    :param axs: axis handles, one per subplot (a single axis is also accepted)
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; data with another channel count is not destriped
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if i_plt is larger than the last subplot index
    '''
    # Init fig
    if fig is None or axs is None:
        # width_ratios needs one entry per column
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), gridspec_kw={'width_ratios': [4] * n_plt})
    # With a single column, subplots returns a bare Axes: wrap it so it can be sized and indexed
    if isinstance(axs, plt.Axes):
        axs = [axs]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt+1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    nc, ns = raw.shape
    if nc == N_CHAN:
        destriped = voltage.destripe(raw, fs=fs)
        # Keep DISPLAY_TIME seconds, then remove the artifact at the beginning
        Xs = destriped[:, :int(DISPLAY_TIME * fs)][:, SAMPLE_SKIP:]
        Tplot = Xs.shape[1] / fs

        # PLOT RAW DATA
        Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X)
        axs[i_plt].set_ylabel('')
        axs[i_plt].set_xlim((0, Tplot * 1e3))
        axs[i_plt].set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect and remove bad channels prior to spike detection
            labels, _ = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data
            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        axs[i_plt].title.set_text(title_plt)

    else:
        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        axs[i_plt].set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs
693def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data',
694 cameras=('left', 'right', 'body'), trials_collection='alf'):
695 """
696 Creates DLC QC plot.
697 Data is searched first locally, then on Alyx. Panels that lack required data are skipped.
699 Required data to create all panels
700 'raw_video_data/_iblrig_bodyCamera.raw.mp4',
701 'raw_video_data/_iblrig_leftCamera.raw.mp4',
702 'raw_video_data/_iblrig_rightCamera.raw.mp4',
703 'alf/_ibl_bodyCamera.dlc.pqt',
704 'alf/_ibl_leftCamera.dlc.pqt',
705 'alf/_ibl_rightCamera.dlc.pqt',
706 'alf/_ibl_bodyCamera.times.npy',
707 'alf/_ibl_leftCamera.times.npy',
708 'alf/_ibl_rightCamera.times.npy',
709 'alf/_ibl_leftCamera.features.pqt',
710 'alf/_ibl_rightCamera.features.pqt',
711 'alf/rightROIMotionEnergy.position.npy',
712 'alf/leftROIMotionEnergy.position.npy',
713 'alf/bodyROIMotionEnergy.position.npy',
714 'alf/_ibl_trials.choice.npy',
715 'alf/_ibl_trials.feedbackType.npy',
716 'alf/_ibl_trials.feedback_times.npy',
717 'alf/_ibl_trials.stimOn_times.npy',
718 'alf/_ibl_wheel.position.npy',
719 'alf/_ibl_wheel.timestamps.npy',
720 'alf/licks.times.npy',
722 :params session_path: Path to session data on disk
723 :params one: ONE instance, if None is given, default ONE is instantiated
724 :returns: Matplotlib figure
725 """
727 one = one or ONE() 1b
728 # hack for running on cortexlab local server
729 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1b
730 one = ONE(base_url='https://alyx.internationalbrainlab.org')
731 data = {} 1b
732 session_path = Path(session_path) 1b
734 # Load data for each camera
735 for cam in cameras: 1b
736 # Load a single frame for each video
737 # Check if video data is available locally,if yes, load a single frame
738 video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4') 1b
739 if video_path.exists(): 1b
740 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
741 # If not, try to stream a frame (try three times)
742 else:
743 try: 1b
744 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1b
745 for tries in range(3):
746 try:
747 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
748 break
749 except Exception:
750 if tries < 2:
751 tries += 1
752 logger.info(f"Streaming {cam} video failed, retrying x{tries}")
753 time.sleep(30)
754 else:
755 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.")
756 data[f'{cam}_frame'] = None
757 except KeyError: 1b
758 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1b
759 data[f'{cam}_frame'] = None 1b
760 # Other camera associated data
761 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1b
762 # Check locally first, then try to load from alyx, if nothing works, set to None
763 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1b
764 continue 1b
765 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1b
766 if len(local_file) > 0: 1b
767 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0])
768 else:
769 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1b
770 if len(alyx_ds) > 0: 1b
771 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0])
772 else:
773 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1b
774 data[f'{cam}_{feat}'] = None 1b
775 # Sometimes there is a file but the object is empty, set to None
776 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1b
777 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.")
778 data[f'{cam}_{feat}'] = None
780 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong
781 assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting." 1b
782 assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting."
783 assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting."
785 # Load session level data
786 for alf_object, collection in zip(['trials', 'wheel', 'licks'], [trials_collection, trials_collection, 'alf']):
787 try:
788 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(collection), alf_object) # load locally
789 continue
790 except ALFObjectNotFound:
791 pass
792 try:
793 # then try from alyx
794 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=collection)
795 except ALFObjectNotFound:
796 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.")
797 data[f'{alf_object}'] = None
799 # Simplify and clean up trials data
800 if data['trials']:
801 data['trials'] = pd.DataFrame(
802 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']})
803 # Discard nan events and too long trials
804 data['trials'] = data['trials'].dropna()
805 data['trials'] = data['trials'].drop(
806 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index)
808 # Make a list of panels, if inputs are missing, instead input a text to display
809 panels = []
810 # Panel A, B, C: Trace on frame
811 for cam in cameras:
812 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None:
813 panels.append((plot_trace_on_frame,
814 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam}))
815 else:
816 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame'))
818 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels
819 if data['trials'] is None:
820 panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7)
821 else:
822 # Panel D: Motion energy
823 camera_dict = {}
824 for cam in cameras: # Remove cameras where we don't have motion energy AND camera times
825 d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')}
826 if not any(x is None for x in d.values()):
827 camera_dict[cam] = d
828 if len(camera_dict) > 0:
829 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']}))
830 else:
831 panels.append((None, 'Data missing\nMotion energy'))
833 # Panel E: Wheel position
834 if data['wheel']:
835 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position,
836 'wheel_time': data['wheel'].timestamps,
837 'trials_df': data['trials']}))
838 else:
839 panels.append((None, 'Data missing\nWheel position'))
841 # Panel F, G: Paw speed and nose speed
842 # Try if all data is there for left cam first, otherwise right
843 for cam in ['left', 'right']:
844 fail = False
845 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None
846 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])):
847 break
848 fail = True
849 if not fail:
850 paw = 'r' if cam == 'left' else 'l'
851 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
852 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam}))
853 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
854 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False,
855 'cam': cam}))
856 else:
857 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2)
859 # Panel H and I: Lick plots
860 if data['licks'] and data['licks'].times.shape[0] > 0:
861 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
862 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
863 else:
864 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)])
866 # Panel J: pupil plot
867 # Try if all data is there for left cam first, otherwise right
868 for cam in ['left', 'right']:
869 fail = False
870 if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None
871 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features'])
872 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))):
873 break
874 fail = True
875 if not fail:
876 panels.append((plot_pupil_diameter_hist,
877 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth,
878 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam}))
879 else:
880 panels.append((None, 'Data missing or corrupt\nPupil diameter'))
882 # Plotting
883 plt.rcParams.update({'font.size': 10})
884 fig = plt.figure(figsize=(17, 10))
885 for i, panel in enumerate(panels):
886 ax = plt.subplot(2, 5, i + 1)
887 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold')
888 # Check if there was in issue with inputs, if yes, print the respective text
889 if panel[0] is None:
890 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center',
891 verticalalignment='center', transform=ax.transAxes)
892 plt.axis('off')
893 else:
894 try:
895 panel[0](**panel[1])
896 except Exception:
897 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc())
898 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold',
899 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
900 plt.axis('off')
901 plt.tight_layout(rect=[0, 0.03, 1, 0.95])
903 return fig