Coverage for ibllib/plots/figures.py: 15%
520 statements
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
« prev ^ index » next coverage.py v7.7.0, created at 2025-03-17 15:25 +0000
1"""
2Module that produces figures, usually for the extraction pipeline
3"""
4import logging
5import time
6from pathlib import Path
7import traceback
8from string import ascii_uppercase
10import numpy as np
11import pandas as pd
12import scipy.signal
13import matplotlib.pyplot as plt
15from ibldsp import voltage
16from ibllib.plots.snapshot import ReportSnapshotProbe, ReportSnapshot
17from one.api import ONE
18import one.alf.io as alfio
19from one.alf.exceptions import ALFObjectNotFound
20from ibllib.io.video import get_video_frame, url_from_eid
21from ibllib.oneibl.data_handlers import ExpectedDataset
22import spikeglx
23import neuropixel
24from brainbox.plot import driftmap
25from brainbox.io.spikeglx import Streamer
26from brainbox.behavior.dlc import SAMPLING, plot_trace_on_frame, plot_wheel_position, plot_lick_hist, \
27 plot_lick_raster, plot_motion_energy_hist, plot_speed_hist, plot_pupil_diameter_hist
28from brainbox.ephys_plots import image_lfp_spectrum_plot, image_rms_plot, plot_brain_regions
29from brainbox.io.one import load_spike_sorting_fast
30from brainbox.behavior import training
31from iblutil.numerical import ismember
32from ibllib.plots.misc import Density
35logger = logging.getLogger(__name__)
def set_axis_label_size(ax, labels=14, ticklabels=12, title=14, cmap=False):
    """
    Normalise the font sizes of the axis labels, tick labels and title of an axis.

    :param ax: matplotlib Axes, modified in place
    :param labels: font size of the x and y axis labels (and of the colorbar label when `cmap` is True)
    :param ticklabels: font size of the tick labels (and of the colorbar tick labels when `cmap` is True)
    :param title: font size of the axis title
    :param cmap: if True, also resize the colorbar attached to the last image drawn on `ax`
    :return: None
    """
    for axis in (ax.xaxis, ax.yaxis):
        axis.get_label().set_fontsize(labels)
    ax.tick_params(labelsize=ticklabels)
    ax.title.set_fontsize(title)

    if not cmap:
        return
    # the colorbar belongs to the most recently added image of the axis
    colorbar_ax = ax.images[-1].colorbar.ax
    colorbar_ax.tick_params(labelsize=ticklabels)
    colorbar_ax.yaxis.get_label().set_fontsize(labels)
def remove_axis_outline(ax):
    """
    Hide both axes and all four spines so that an unused axis leaves no visible outline.

    :param ax: matplotlib Axes, modified in place
    :return: None
    """
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    for side in ('right', 'top', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
class BehaviourPlots(ReportSnapshot):
    """Behavioural plots."""

    @property
    def signature(self):
        signature = {
            'input_files': [
                ('*trials.table.pqt', self.trials_collection, True),
            ],
            'output_files': [
                ('psychometric_curve.png', 'snapshot/behaviour', True),
                ('chronometric_curve.png', 'snapshot/behaviour', True),
                ('reaction_time_with_trials.png', 'snapshot/behaviour', True)
            ]
        }
        return signature

    def __init__(self, eid, session_path=None, one=None, **kwargs):
        """
        Generate and upload behaviour plots.

        Parameters
        ----------
        eid : str, uuid.UUID
            An experiment UUID.
        session_path : pathlib.Path, str
            A session path. If None, it is obtained from `one.eid2path(eid)`.
        one : one.api.One
            An instance of ONE for registration to Alyx.
        trials_collection : str
            The location of the trials data (default: 'alf'). May also be passed as `task_collection`.
        kwargs
            Arguments for ReportSnapshot constructor.
        """
        self.one = one
        self.eid = eid
        self.session_path = Path(session_path or self.one.eid2path(self.eid))
        # Accept both the documented `trials_collection` keyword and the `task_collection` keyword read
        # historically; both are popped so that neither is forwarded to the ReportSnapshot constructor.
        # Must be set before super().__init__ as the `signature` property reads it.
        trials_collection = kwargs.pop('trials_collection', None)
        task_collection = kwargs.pop('task_collection', None)
        self.trials_collection = trials_collection or task_collection or 'alf'
        super(BehaviourPlots, self).__init__(self.session_path, self.eid, one=self.one,
                                             **kwargs)
        # Output directory should mirror trials collection, sans 'alf' part
        self.output_directory = self.session_path.joinpath(
            'snapshot', 'behaviour', self.trials_collection.removeprefix('alf').strip('/'))
        self.output_directory.mkdir(exist_ok=True, parents=True)

    def _run(self):
        """
        Load the trials object and save the psychometric, chronometric and reaction time plots.

        :return: list of the saved png paths, in the order of the output signature
        """
        trials = alfio.load_object(self.session_path.joinpath(self.trials_collection), 'trials')
        if self.one:
            title = self.one.path2ref(self.session_path, as_dict=False)
        else:
            title = '_'.join(list(self.session_path.parts[-3:]))

        plots = (
            (training.plot_psychometric, "psychometric_curve.png"),
            (training.plot_reaction_time, "chronometric_curve.png"),
            (training.plot_reaction_time_over_trials, "reaction_time_with_trials.png"),
        )
        output_files = []
        for plot_function, filename in plots:
            fig, ax = plot_function(trials, title=title, figsize=(8, 6))
            set_axis_label_size(ax)
            save_path = Path(self.output_directory).joinpath(filename)
            output_files.append(save_path)
            fig.savefig(save_path)
            plt.close(fig)

        return output_files
# TODO put into histology and alignment pipeline
class HistologySlices(ReportSnapshotProbe):
    """Plots coronal and sagittal slice showing electrode locations."""

    def _run(self):
        """
        Plot two tilted atlas slices with the electrode sites overlaid, plus a brain region strip,
        and save the figure as ``histology_slices.png``.

        :return: list with the saved figure path, or an empty list when `hist_lookup` does not map the
            histology status to a positive value
        """
        assert self.pid
        assert self.brain_atlas

        self.histology_status = self.get_histology_status()
        electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        if not self.hist_lookup[self.histology_status] > 0:
            return []

        mlapdv = electrodes['mlapdv']
        fig = plt.figure(figsize=(12, 9))
        grid = fig.add_gridspec(2, 2, width_ratios=[.95, .05])

        # top panel: slice along axis 1, electrodes shown as ml (x) vs dv (y)
        ax_top = fig.add_subplot(grid[0, 0])
        self.brain_atlas.plot_tilted_slice(mlapdv, 1, ax=ax_top)
        ax_top.scatter(mlapdv[:, 0] * 1e6, mlapdv[:, 2] * 1e6, s=8, c='r')
        ax_top.set_title(f"{self.pid_label}")

        # bottom panel: slice along axis 0, electrodes shown as ap (x) vs dv (y)
        ax_bottom = fig.add_subplot(grid[1, 0])
        self.brain_atlas.plot_tilted_slice(mlapdv, 0, ax=ax_bottom)
        ax_bottom.scatter(mlapdv[:, 1] * 1e6, mlapdv[:, 2] * 1e6, s=8, c='r')

        # right column spanning both rows: brain regions along the probe
        ax_regions = fig.add_subplot(grid[:, 1])
        plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax_regions,
                           title=self.histology_status)

        save_path = Path(self.output_directory).joinpath("histology_slices.png")
        fig.savefig(save_path)
        plt.close(fig)
        return [save_path]

    def get_probe_signature(self):
        """Set `self.signature`: optional electrode site inputs and the histology slice snapshot output."""
        probe_collection = f'alf/{self.pname}'
        input_signature = [(f'electrodeSites.{attribute}.npy', probe_collection, False)
                           for attribute in ('localCoordinates', 'brainLocationIds_ccf_2017', 'mlapdv')]
        output_signature = [('histology_slices.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class LfpPlots(ReportSnapshotProbe):
    """
    Plots LFP spectrum and LFP RMS plots, each next to a brain region strip when histology is available.
    """

    def _plot_brain_regions_panel(self, ax, electrodes):
        """Draw the brain regions of the electrode sites on `ax`, or hide `ax` when they are unavailable."""
        histology_status = getattr(self, 'histology_status', None)
        # `electrodes` is None on the server, where it is never loaded: hide the panel instead of failing
        if histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax,
                               title=histology_status)
            set_axis_label_size(ax)
        else:
            remove_axis_outline(ax)

    def _save_figure(self, fig, filename):
        """Save `fig` as `filename` in the output directory, close it and return the saved path."""
        save_path = Path(self.output_directory).joinpath(filename)
        fig.savefig(save_path)
        plt.close(fig)
        return save_path

    def _run(self):
        """
        Plot the LFP spectral density and the LFP RMS over time of the probe.

        Off the server, the histology status and electrode sites are fetched so that brain regions can be
        drawn next to each image; on the server the brain region panel is hidden.

        :return: list of the saved png paths (spectrum, then RMS)
        """
        assert self.pid

        output_files = []
        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')
        raw_ephys_path = self.session_path.joinpath(f'raw_ephys_data/{self.pname}')

        # lfp spectrum
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(raw_ephys_path, 'ephysSpectralDensityLF', namespace='iblqc')
        image_lfp_spectrum_plot(lfp.power, lfp.freqs, clim=[-65, -95], fig_kwargs={'figsize': (8, 6)}, ax=axs[0],
                                display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._plot_brain_regions_panel(axs[1], electrodes)
        output_files.append(self._save_figure(fig, "lfp_spectrum.png"))

        # lfp rms
        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        lfp = alfio.load_object(raw_ephys_path, 'ephysTimeRmsLF', namespace='iblqc')
        image_rms_plot(lfp.rms, lfp.timestamps, median_subtract=False, band='LFP', clim=[-35, -45], ax=axs[0],
                       cmap='inferno', fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._plot_brain_regions_panel(axs[1], electrodes)
        output_files.append(self._save_figure(fig, "lfp_rms.png"))

        return output_files

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsLF.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsLF.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.freqs.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysSpectralDensityLF.power.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('lfp_spectrum.png', f'snapshot/{self.pname}', True),
                            ('lfp_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class ApPlots(ReportSnapshotProbe):
    """
    Plots AP RMS plots, next to a brain region strip when histology is available.
    """

    def _plot_brain_regions_panel(self, ax, electrodes):
        """Draw the brain regions of the electrode sites on `ax`, or hide `ax` when they are unavailable."""
        histology_status = getattr(self, 'histology_status', None)
        # `electrodes` is None on the server, where it is never loaded: hide the panel instead of failing
        if histology_status and electrodes is not None:
            plot_brain_regions(electrodes['atlas_id'], brain_regions=self.brain_regions, display=True, ax=ax,
                               title=histology_status)
            set_axis_label_size(ax)
        else:
            remove_axis_outline(ax)

    def _run(self):
        """
        Plot the AP band RMS over time of the probe.

        Off the server, the histology status and electrode sites are fetched so that brain regions can be
        drawn next to the image; on the server the brain region panel is hidden.

        :return: list with the saved png path
        """
        assert self.pid

        electrodes = None
        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

        # TODO need to figure out the clim range
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        ap = alfio.load_object(self.session_path.joinpath(f'raw_ephys_data/{self.pname}'), 'ephysTimeRmsAP', namespace='iblqc')
        image_rms_plot(ap.rms, ap.timestamps, median_subtract=False, band='AP', ax=axs[0],
                       fig_kwargs={'figsize': (8, 6)}, display=True, title=f"{self.pid_label}")
        set_axis_label_size(axs[0], cmap=True)
        self._plot_brain_regions_panel(axs[1], electrodes)

        save_path = Path(self.output_directory).joinpath("ap_rms.png")
        fig.savefig(save_path)
        plt.close(fig)
        return [save_path]

    def get_probe_signature(self):
        input_signature = [('_iblqc_ephysTimeRmsAP.rms.npy', f'raw_ephys_data/{self.pname}', True),
                           ('_iblqc_ephysTimeRmsAP.timestamps.npy', f'raw_ephys_data/{self.pname}', True),
                           ('electrodeSites.localCoordinates.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}', False),
                           ('electrodeSites.mlapdv.npy', f'alf/{self.pname}', False)]
        output_signature = [('ap_rms.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}
class SpikeSorting(ReportSnapshotProbe):
    """
    Plots the spike sorting driftmap (spike depth against time) of each spike sorting run of a probe
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def _plot_driftmap(self, spikes, clusters, channels, collection):
        """
        Plot the driftmap of one spike sorting run next to its brain region strip and save it.

        :param spikes: spikes object with `times`, `depths` and `clusters` attributes
        :param clusters: clusters object with a `depths` attribute
        :param channels: dict-like with 'axial_um' and, when histology is available, 'atlas_id'
        :param collection: ALF collection of the run, at or below 'alf/<pname>'
        :return: (saved png path, figure, axes)
        """
        fig, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        driftmap(spikes.times, spikes.depths, t_bin=0.007, d_bin=10, vmax=0.5, ax=axs[0])
        title_str = f"{self.pid_label}, {collection}, {self.pid} \n " \
                    f"{spikes.clusters.size:_} spikes, {clusters.depths.size:_} clusters"
        ylim = (0, np.max(channels['axial_um']))
        axs[0].set(ylim=ylim, title=title_str)
        # The run label is the sub-folder below alf/<pname>; the un-nested collection is labelled 'ks2matlab'
        run_label = str(Path(collection).relative_to(f'alf/{self.pname}'))
        run_label = "ks2matlab" if run_label == '.' else run_label
        outfile = self.output_directory.joinpath(f"spike_sorting_raster_{run_label}.png")
        set_axis_label_size(axs[0])

        if self.histology_status and 'atlas_id' in channels:
            plot_brain_regions(channels['atlas_id'], channel_depths=channels['axial_um'],
                               brain_regions=self.brain_regions, display=True, ax=axs[1], title=self.histology_status)
            axs[1].set(ylim=ylim)
            set_axis_label_size(axs[1])
        else:
            remove_axis_outline(axs[1])

        fig.savefig(outfile)
        plt.close(fig)

        return outfile, fig, axs

    def _run(self, collection=None):
        """
        Plot a driftmap for each spike sorting run of the probe.

        On the server, only the run in `collection` is plotted, from local files. Elsewhere, every run with a
        'spikes.times.npy' dataset under 'alf/<pname>*' is plotted, unless all expected outputs already exist
        and their number matches the number of runs.

        :param collection: ALF collection of the run to plot (required on the server)
        :return: list of the saved png paths
        """
        if self.location == 'server':
            assert collection
            alf_path = self.session_path.joinpath(collection)
            spikes = alfio.load_object(alf_path, 'spikes')
            clusters = alfio.load_object(alf_path, 'clusters')
            channels = alfio.load_object(alf_path, 'channels')
            channels['axial_um'] = channels['localCoordinates'][:, 1]
            outfile, _, _ = self._plot_driftmap(spikes, clusters, channels, collection)
            return [outfile]

        self.histology_status = self.get_histology_status()
        all_here, output_files = self.assert_expected(self.output_files, silent=True)
        spike_sorting_runs = self.one.list_datasets(self.eid, filename='spikes.times.npy', collection=f'alf/{self.pname}*')
        if all_here and len(output_files) == len(spike_sorting_runs):
            return output_files
        logger.info(self.output_directory)
        # Every run is re-plotted below: start from an empty list so that already existing snapshots
        # are not listed twice
        output_files = []
        for run in spike_sorting_runs:
            collection = Path(run).parent.as_posix()
            spikes, clusters, channels = load_spike_sorting_fast(
                eid=self.eid, probe=self.pname, one=self.one, nested=False, collection=collection,
                dataset_types=['spikes.depths'], brain_regions=self.brain_regions)

            if 'atlas_id' not in channels.keys():
                channels = self.get_channels('channels', collection)

            outfile, _, _ = self._plot_driftmap(spikes, clusters, channels, collection)
            output_files.append(outfile)

        return output_files

    def get_probe_signature(self):
        input_signature = [('spikes.times.npy', f'alf/{self.pname}*', True),
                           ('spikes.amps.npy', f'alf/{self.pname}*', True),
                           ('spikes.depths.npy', f'alf/{self.pname}*', True),
                           ('clusters.depths.npy', f'alf/{self.pname}*', True),
                           ('channels.localCoordinates.npy', f'alf/{self.pname}*', False),
                           ('channels.mlapdv.npy', f'alf/{self.pname}*', False),
                           ('channels.brainLocationIds_ccf_2017.npy', f'alf/{self.pname}*', False)]
        output_signature = [('spike_sorting_raster*.png', f'snapshot/{self.pname}', True)]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def get_signatures(self, **kwargs):
        """
        Expand the input signature to every local folder below 'alf' that contains a 'spikes.times.npy' file,
        falling back to the wildcard signature when there is none.
        """
        files_spikes = Path(self.session_path).joinpath('alf').rglob('spikes.times.npy')
        folder_probes = [f.parent for f in files_spikes]
        full_input_files = []
        for sig in self.signature['input_files']:
            for folder in folder_probes:
                full_input_files.append((sig[0], str(folder.relative_to(self.session_path)), sig[2]))
        if len(full_input_files) != 0:
            self.input_files = full_input_files
        else:
            self.input_files = self.signature['input_files']
        self.output_files = self.signature['output_files']
        self.input_files = [ExpectedDataset.input(*i) for i in self.input_files]
        self.output_files = [ExpectedDataset.output(*i) for i in self.output_files]
class BadChannelsAp(ReportSnapshotProbe):
    """
    Plots raw electrophysiology AP band, with the bad channels detected on a 1 second snippet overlaid
    task = BadChannelsAp(pid, one=one)
    :param session_path: session path
    :param probe_id: str, UUID of the probe insertion for which to create the plot
    :param **kwargs: keyword arguments passed to tasks.Task
    """

    def get_probe_signature(self):
        """Set `self.signature`: AP band meta/ch inputs and the four bad channel snapshot outputs."""
        pname = self.pname
        input_signature = [('*ap.meta', f'raw_ephys_data/{pname}', True),
                           ('*ap.ch', f'raw_ephys_data/{pname}', False)]
        # Summary figure plus one figure per displayed trace (highpass, destripe, difference); each listed once
        output_signature = [('raw_ephys_bad_channels.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_highpass.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_destripe.png', f'snapshot/{pname}', True),
                            ('raw_ephys_bad_channels_difference.png', f'snapshot/{pname}', True),
                            ]
        self.signature = {'input_files': input_signature, 'output_files': output_signature}

    def _run(self):
        """
        Load 1 second of AP band data, detect the bad channels and save the bad channel snapshots.

        Off the server, the data is loaded through `Streamer` starting 30 minutes into the recording. On the
        server, it is read from the local AP binary file: 30 minutes in, or 500 seconds before the end (never
        before the start) when the recording is shorter than 30 minutes.

        :return: list of the saved png paths. If all 4 snapshots already exist, they are returned as is. An
            empty list is returned on the server when no AP binary file is found.
        """
        assert self.pid
        self.eqcs = []
        T0 = 60 * 30  # start of the displayed snippet, in seconds
        SNAPSHOT_LABEL = "raw_ephys_bad_channels"
        output_files = list(self.output_directory.glob(f'{SNAPSHOT_LABEL}*'))
        if len(output_files) == 4:
            return output_files

        self.output_directory.mkdir(exist_ok=True, parents=True)

        if self.location != 'server':
            self.histology_status = self.get_histology_status()
            electrodes = self.get_channels('electrodeSites', f'alf/{self.pname}')

            if 'atlas_id' in electrodes.keys():
                electrodes['ibr'] = ismember(electrodes['atlas_id'], self.brain_regions.id)[1]
                electrodes['acronym'] = self.brain_regions.acronym[electrodes['ibr']]
                electrodes['name'] = self.brain_regions.name[electrodes['ibr']]
                electrodes['title'] = self.histology_status
            else:
                electrodes = None

            nsecs = 1
            sr = Streamer(pid=self.pid, one=self.one, remove_cached=False, typ='ap')
            s0 = T0 * sr.fs
            tsel = slice(int(s0), int(s0) + int(nsecs * sr.fs))
            # Important: remove sync channel from raw data, and transpose
            raw = sr[tsel, :-sr.nsync].T

        else:
            electrodes = None
            ap_file = next(self.session_path.joinpath('raw_ephys_data', self.pname).glob('*ap.*bin'), None)
            if ap_file is None:
                return []
            sr = spikeglx.Reader(ap_file)
            # If T0 is greater than recording length, take 500 sec before end; clamp at 0 so that
            # recordings shorter than 500 s do not yield a negative start
            if sr.rl < T0:
                T0 = max(0, int(sr.rl - 500))
            raw = sr[int((sr.fs * T0)):int((sr.fs * (T0 + 1))), :-sr.nsync].T

        if sr.meta.get('NP2.4_shank', None) is not None:
            h = neuropixel.trace_header(sr.major_version, nshank=4)
            h = neuropixel.split_trace_header(h, shank=int(sr.meta.get('NP2.4_shank')))
        else:
            h = neuropixel.trace_header(sr.major_version, nshank=np.unique(sr.geometry['shank']).size)

        channel_labels, channel_features = voltage.detect_bad_channels(raw, sr.fs)
        _, eqcs, output_files = ephys_bad_channels(
            raw=raw, fs=sr.fs, channel_labels=channel_labels, channel_features=channel_features, h=h, channels=electrodes,
            title=SNAPSHOT_LABEL, destripe=True, save_dir=self.output_directory, br=self.brain_regions, pid_info=self.pid_label)
        self.eqcs = eqcs
        return output_files
def ephys_bad_channels(raw, fs, channel_labels, channel_features, h=None, channels=None, title="ephys_bad_channels",
                       save_dir=None, destripe=False, eqcs=None, br=None, pid_info=None, plot_backend='matplotlib'):
    """
    Display voltage traces with the detected bad channels overlaid, plus a summary figure of the channel
    features used for the detection.

    :param raw: voltage array of shape (nc, ns), channels x samples
    :param fs: sampling frequency (Hz); fs >= 2600 is displayed as AP band, otherwise as LFP band
    :param channel_labels: array of per-channel labels; 1 = dead, 2 = noisy, 3 = outside (other values unmarked)
    :param channel_features: dict with per-channel 'rms_raw', 'psd_hf', 'xcor_hf' and 'xcor_lf' arrays
    :param h: trace header forwarded to `voltage.destripe` (only used when `destripe` is True)
    :param channels: optional dict with 'atlas_id' (and optionally 'title') used to draw the brain region strip
    :param title: title of the features figure, also the prefix of the saved file names
    :param save_dir: if given, all figures are saved in this folder and the saved paths are returned
    :param destripe: if True, also display the destriped data and its difference with the filtered data
    :param eqcs: optional list to which the created trace displays are appended (a new list if None)
    :param br: brain regions object forwarded to the brain region plots
    :param pid_info: label prefixed to the trace figure titles (matplotlib backend only)
    :param plot_backend: 'matplotlib', or any other value to use the `viewspikes` GUI
    :return: (features figure, trace displays), plus the list of saved paths when `save_dir` is given
    """
    nc, ns = raw.shape
    rl = ns / fs

    def gain2level(gain):
        # symmetric display limits +/- 4 * 10 ** (gain / 20)
        return 10 ** (gain / 20) * 4 * np.array([-1, 1])

    if fs >= 2600:  # AP band
        ylim_rms = [0, 100]
        ylim_psd_hf = [0, 0.1]
        eqc_xrange = [450, 500]  # displayed time window, in ms
        butter_kwargs = {'N': 3, 'Wn': 300 / fs * 2, 'btype': 'highpass'}
        eqc_gain = - 90
        eqc_levels = gain2level(eqc_gain)
    else:
        # we are working with the LFP
        ylim_rms = [0, 1000]
        ylim_psd_hf = [0, 1]
        eqc_xrange = [450, 950]  # displayed time window, in ms
        butter_kwargs = {'N': 3, 'Wn': np.array([2, 125]) / fs * 2, 'btype': 'bandpass'}
        eqc_gain = - 78
        eqc_levels = gain2level(eqc_gain)

    inoisy = np.where(channel_labels == 2)[0]
    idead = np.where(channel_labels == 1)[0]
    ioutside = np.where(channel_labels == 3)[0]

    # display voltage traces
    eqcs = [] if eqcs is None else eqcs
    # butterworth, for display only
    sos = scipy.signal.butter(**butter_kwargs, output='sos')
    butt = scipy.signal.sosfiltfilt(sos, raw)

    if plot_backend == 'matplotlib':
        _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
        eqcs.append(Density(butt, fs=fs, taxis=1, ax=axs[0], title='highpass', vmin=eqc_levels[0], vmax=eqc_levels[1]))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density(
                dest, fs=fs, taxis=1, ax=axs[0], title='destripe', vmin=eqc_levels[0], vmax=eqc_levels[1]))
            _, axs = plt.subplots(1, 2, gridspec_kw={'width_ratios': [.95, .05]}, figsize=(16, 9))
            eqcs.append(Density((butt - dest), fs=fs, taxis=1, ax=axs[0], title='difference', vmin=eqc_levels[0],
                                vmax=eqc_levels[1]))

        for eqc in eqcs:
            # mark bad channels along the whole time axis (x axis in ms)
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='goldenrod', s=4)
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='r', s=4)
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ax.scatter(x.flatten(), y.flatten(), c='b', s=4)

            eqc.ax.set_xlim(*eqc_xrange)
            eqc.ax.set_ylim(0, nc)
            eqc.ax.set_ylabel('Channel index')
            eqc.ax.set_title(f'{pid_info}_{eqc.title}')
            set_axis_label_size(eqc.ax)

            ax = eqc.figure.axes[1]
            if channels is not None:
                chn_title = channels.get('title', None)
                plot_brain_regions(channels['atlas_id'], brain_regions=br, display=True, ax=ax,
                                   title=chn_title)
                set_axis_label_size(ax)
            else:
                remove_axis_outline(ax)
    else:
        from viewspikes.gui import viewephys  # noqa
        eqcs.append(viewephys(butt, fs=fs, channels=channels, title='highpass', br=br))

        if destripe:
            dest = voltage.destripe(raw, fs=fs, h=h, channel_labels=channel_labels)
            eqcs.append(viewephys(dest, fs=fs, channels=channels, title='destripe', br=br))
            eqcs.append(viewephys((butt - dest), fs=fs, channels=channels, title='difference', br=br))

        for eqc in eqcs:
            y, x = np.meshgrid(ioutside, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(164, 142, 35), label='outside')
            y, x = np.meshgrid(inoisy, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(255, 0, 0), label='noisy')
            y, x = np.meshgrid(idead, np.linspace(0, rl * 1e3, 500))
            eqc.ctrl.add_scatter(x.flatten(), y.flatten(), rgb=(0, 0, 255), label='dead')

        eqcs[0].ctrl.set_gain(eqc_gain)
        eqcs[0].resize(1960, 1200)
        eqcs[0].viewBox_seismic.setXRange(*eqc_xrange)
        eqcs[0].viewBox_seismic.setYRange(0, nc)
        eqcs[0].ctrl.propagate()

    # display features
    fig, axs = plt.subplots(2, 2, sharex=True, figsize=[16, 9], tight_layout=True)
    fig.suptitle(title)
    axs[0, 0].plot(channel_features['rms_raw'] * 1e6)
    axs[0, 0].set(title='rms', xlabel='channel number', ylabel='rms (uV)', ylim=ylim_rms)

    axs[1, 0].plot(channel_features['psd_hf'])
    # clip the noisy markers just below the top of the y axis (band dependent) so that they stay visible
    axs[1, 0].plot(inoisy, np.minimum(channel_features['psd_hf'][inoisy], ylim_psd_hf[1] * 0.999), 'xr')
    axs[1, 0].set(title='PSD above 80% Nyquist', xlabel='channel number', ylabel='PSD (uV ** 2 / Hz)', ylim=ylim_psd_hf)
    axs[1, 0].legend(['psd', 'noisy'])

    axs[0, 1].plot(channel_features['xcor_hf'])
    axs[0, 1].plot(channel_features['xcor_lf'])

    axs[0, 1].plot(idead, channel_features['xcor_hf'][idead], 'xb')
    axs[0, 1].plot(ioutside, channel_features['xcor_lf'][ioutside], 'xy')
    axs[0, 1].set(title='Similarity', xlabel='channel number', ylabel='', ylim=[-1.5, 0.5])
    axs[0, 1].legend(['detrend', 'trend', 'dead', 'outside'])

    fscale, psd = scipy.signal.welch(raw * 1e6, fs=fs)  # units; uV ** 2 / Hz
    axs[1, 1].imshow(20 * np.log10(psd).T, extent=[0, nc - 1, fscale[0], fscale[-1]], origin='lower', aspect='auto',
                     vmin=-50, vmax=-20)
    axs[1, 1].set(title='PSD', xlabel='channel number', ylabel="Frequency (Hz)")
    axs[1, 1].plot(idead, idead * 0 + fs / 4, 'xb')
    axs[1, 1].plot(inoisy, inoisy * 0 + fs / 4, 'xr')
    axs[1, 1].plot(ioutside, ioutside * 0 + fs / 4, 'xy')

    if save_dir is not None:
        output_files = [Path(save_dir).joinpath(f"{title}.png")]
        fig.savefig(output_files[0])
        for eqc in eqcs:
            if plot_backend == 'matplotlib':
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.title}.png"))
                eqc.figure.savefig(str(output_files[-1]))
            else:
                output_files.append(Path(save_dir).joinpath(f"{title}_{eqc.windowTitle()}.png"))
                eqc.grab().save(str(output_files[-1]))
        return fig, eqcs, output_files
    else:
        return fig, eqcs
def raw_destripe(raw, fs, t0, i_plt, n_plt,
                 fig=None, axs=None, savedir=None, detect_badch=True,
                 SAMPLE_SKIP=200, DISPLAY_TIME=0.05, N_CHAN=384,
                 MIN_X=-0.00011, MAX_X=0.00011):
    '''
    Destripe a raw ephys snippet and display it as a density image in one panel of a row of subplots.

    :param raw: raw ephys data, Nc x Ns (channels x samples)
    :param fs: sampling freq (Hz) of the raw ephys data
    :param t0: time (s) of ephys sample beginning from session start (used in the panel title)
    :param i_plt: index of the subplot in which to display the image (starts at 0, has to be < n_plt)
    :param n_plt: total number of subplots on the figure
    :param fig: figure handle; a new figure is created when either `fig` or `axs` is None
    :param axs: axes handles
    :param savedir: filename, including directory, to save figure to
    :param detect_badch: boolean, to detect or not bad channels
    :param SAMPLE_SKIP: number of samples to skip at origin of ephys sample for display
    :param DISPLAY_TIME: time (s) to display
    :param N_CHAN: number of expected channels on the probe; data with another channel count is not destriped
    :param MIN_X: min voltage for color range
    :param MAX_X: max voltage for color range
    :return: fig, axs
    :raises ValueError: if `i_plt` is not a valid subplot index
    '''
    # Init fig
    if fig is None or axs is None:
        # width_ratios needs one value per column
        fig, axs = plt.subplots(nrows=1, ncols=n_plt, figsize=(14, 5), gridspec_kw={'width_ratios': [4] * n_plt})
    if not isinstance(axs, (list, tuple, np.ndarray)):
        # plt.subplots returns a bare Axes for a single panel: wrap it so that it can be indexed by i_plt
        axs = [axs]

    if i_plt > len(axs) - 1:  # Error
        raise ValueError(f'The given increment of subplot ({i_plt + 1}) '
                         f'is larger than the total number of subplots ({len(axs)})')

    nc, _ = raw.shape
    if nc == N_CHAN:
        destriped = voltage.destripe(raw, fs=fs)
        # keep the first DISPLAY_TIME seconds, minus the first SAMPLE_SKIP samples (artifact at beginning)
        Xs = destriped[:, SAMPLE_SKIP:int(DISPLAY_TIME * fs)]
        Tplot = Xs.shape[1] / fs

        # display the sign-inverted destriped data, x axis in ms
        Density(-Xs, fs=fs, taxis=1, ax=axs[i_plt], vmin=MIN_X, vmax=MAX_X)
        axs[i_plt].set_ylabel('')
        axs[i_plt].set_xlim((0, Tplot * 1e3))
        axs[i_plt].set_ylim((0, nc))

        # Init title
        title_plt = f't0 = {int(t0 / 60)} min'

        if detect_badch:
            # Detect bad channels on the full raw snippet
            labels, _ = voltage.detect_bad_channels(raw, fs)
            idx_badchan = np.where(labels != 0)[0]
            # Plot bad channels on raw data
            x, y = np.meshgrid(idx_badchan, np.linspace(0, Tplot * 1e3, 20))
            axs[i_plt].plot(y.flatten(), x.flatten(), '.k', markersize=1)
            # Append title
            title_plt += f', n={len(idx_badchan)} bad ch'

        # Set title
        axs[i_plt].title.set_text(title_plt)

    else:
        axs[i_plt].title.set_text(f'CANNOT DESTRIPE, N CHAN = {nc}')

    # Amend some axis style
    if i_plt > 0:
        axs[i_plt].set_yticklabels('')

    # Fig layout
    fig.tight_layout()
    if savedir is not None:
        fig.savefig(fname=savedir)

    return fig, axs
694def dlc_qc_plot(session_path, one=None, device_collection='raw_video_data',
695 cameras=('left', 'right', 'body'), trials_collection='alf'):
696 """
697 Creates DLC QC plot.
698 Data is searched first locally, then on Alyx. Panels that lack required data are skipped.
700 Required data to create all panels
701 'raw_video_data/_iblrig_bodyCamera.raw.mp4',
702 'raw_video_data/_iblrig_leftCamera.raw.mp4',
703 'raw_video_data/_iblrig_rightCamera.raw.mp4',
704 'alf/_ibl_bodyCamera.dlc.pqt',
705 'alf/_ibl_leftCamera.dlc.pqt',
706 'alf/_ibl_rightCamera.dlc.pqt',
707 'alf/_ibl_bodyCamera.times.npy',
708 'alf/_ibl_leftCamera.times.npy',
709 'alf/_ibl_rightCamera.times.npy',
710 'alf/_ibl_leftCamera.features.pqt',
711 'alf/_ibl_rightCamera.features.pqt',
712 'alf/rightROIMotionEnergy.position.npy',
713 'alf/leftROIMotionEnergy.position.npy',
714 'alf/bodyROIMotionEnergy.position.npy',
715 'alf/_ibl_trials.choice.npy',
716 'alf/_ibl_trials.feedbackType.npy',
717 'alf/_ibl_trials.feedback_times.npy',
718 'alf/_ibl_trials.stimOn_times.npy',
719 'alf/_ibl_wheel.position.npy',
720 'alf/_ibl_wheel.timestamps.npy',
721 'alf/licks.times.npy',
723 :params session_path: Path to session data on disk
724 :params one: ONE instance, if None is given, default ONE is instantiated
725 :returns: Matplotlib figure
726 """
728 one = one or ONE() 1b
729 # hack for running on cortexlab local server
730 if one.alyx.base_url == 'https://alyx.cortexlab.net': 1b
731 one = ONE(base_url='https://alyx.internationalbrainlab.org')
732 data = {} 1b
733 session_path = Path(session_path) 1b
735 # Load data for each camera
736 for cam in cameras: 1b
737 # Load a single frame for each video
738 # Check if video data is available locally,if yes, load a single frame
739 video_path = session_path.joinpath(device_collection, f'_iblrig_{cam}Camera.raw.mp4') 1b
740 if video_path.exists(): 1b
741 data[f'{cam}_frame'] = get_video_frame(video_path, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
742 # If not, try to stream a frame (try three times)
743 else:
744 try: 1b
745 video_url = url_from_eid(one.path2eid(session_path), one=one)[cam] 1b
746 for tries in range(3):
747 try:
748 data[f'{cam}_frame'] = get_video_frame(video_url, frame_number=5 * 60 * SAMPLING[cam])[:, :, 0]
749 break
750 except Exception:
751 if tries < 2:
752 tries += 1
753 logger.info(f"Streaming {cam} video failed, retrying x{tries}")
754 time.sleep(30)
755 else:
756 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.")
757 data[f'{cam}_frame'] = None
758 except KeyError: 1b
759 logger.warning(f"Could not load video frame for {cam} cam. Skipping trace on frame.") 1b
760 data[f'{cam}_frame'] = None 1b
761 # Other camera associated data
762 for feat in ['dlc', 'times', 'features', 'ROIMotionEnergy']: 1b
763 # Check locally first, then try to load from alyx, if nothing works, set to None
764 if feat == 'features' and cam == 'body': # this doesn't exist for body cam 1b
765 continue 1b
766 local_file = list(session_path.joinpath('alf').glob(f'*{cam}Camera.{feat}*')) 1b
767 if len(local_file) > 0: 1b
768 data[f'{cam}_{feat}'] = alfio.load_file_content(local_file[0])
769 else:
770 alyx_ds = [ds for ds in one.list_datasets(one.path2eid(session_path)) if f'{cam}Camera.{feat}' in ds] 1b
771 if len(alyx_ds) > 0: 1b
772 data[f'{cam}_{feat}'] = one.load_dataset(one.path2eid(session_path), alyx_ds[0])
773 else:
774 logger.warning(f"Could not load _ibl_{cam}Camera.{feat} some plots have to be skipped.") 1b
775 data[f'{cam}_{feat}'] = None 1b
776 # Sometimes there is a file but the object is empty, set to None
777 if data[f'{cam}_{feat}'] is not None and len(data[f'{cam}_{feat}']) == 0: 1b
778 logger.warning(f"Object loaded from _ibl_{cam}Camera.{feat} is empty, some plots have to be skipped.")
779 data[f'{cam}_{feat}'] = None
781 # If we have no frame and/or no DLC and/or no times for all cams, raise an error, something is really wrong
782 assert any(data[f'{cam}_frame'] is not None for cam in cameras), "No camera data could be loaded, aborting." 1b
783 assert any(data[f'{cam}_dlc'] is not None for cam in cameras), "No DLC data could be loaded, aborting."
784 assert any(data[f'{cam}_times'] is not None for cam in cameras), "No camera times data could be loaded, aborting."
786 # Load session level data
787 for alf_object, collection in zip(['trials', 'wheel', 'licks'], [trials_collection, trials_collection, 'alf']):
788 try:
789 data[f'{alf_object}'] = alfio.load_object(session_path.joinpath(collection), alf_object) # load locally
790 continue
791 except ALFObjectNotFound:
792 pass
793 try:
794 # then try from alyx
795 data[f'{alf_object}'] = one.load_object(one.path2eid(session_path), alf_object, collection=collection)
796 except ALFObjectNotFound:
797 logger.warning(f"Could not load {alf_object} object, some plots have to be skipped.")
798 data[f'{alf_object}'] = None
800 # Simplify and clean up trials data
801 if data['trials']:
802 data['trials'] = pd.DataFrame(
803 {k: data['trials'][k] for k in ['stimOn_times', 'feedback_times', 'choice', 'feedbackType']})
804 # Discard nan events and too long trials
805 data['trials'] = data['trials'].dropna()
806 data['trials'] = data['trials'].drop(
807 data['trials'][(data['trials']['feedback_times'] - data['trials']['stimOn_times']) > 10].index)
809 # Make a list of panels, if inputs are missing, instead input a text to display
810 panels = []
811 # Panel A, B, C: Trace on frame
812 for cam in cameras:
813 if data[f'{cam}_frame'] is not None and data[f'{cam}_dlc'] is not None:
814 panels.append((plot_trace_on_frame,
815 {'frame': data[f'{cam}_frame'], 'dlc_df': data[f'{cam}_dlc'], 'cam': cam}))
816 else:
817 panels.append((None, f'Data missing\n{cam.capitalize()} cam trace on frame'))
819 # If trials data is not there, we cannot plot any of the trial average plots, skip all remaining panels
820 if data['trials'] is None:
821 panels.extend([(None, 'No trial data,\ncannot compute trial avgs')] * 7)
822 else:
823 # Panel D: Motion energy
824 camera_dict = {}
825 for cam in cameras: # Remove cameras where we don't have motion energy AND camera times
826 d = {'motion_energy': data.get(f'{cam}_ROIMotionEnergy'), 'times': data.get(f'{cam}_times')}
827 if not any(x is None for x in d.values()):
828 camera_dict[cam] = d
829 if len(camera_dict) > 0:
830 panels.append((plot_motion_energy_hist, {'camera_dict': camera_dict, 'trials_df': data['trials']}))
831 else:
832 panels.append((None, 'Data missing\nMotion energy'))
834 # Panel E: Wheel position
835 if data['wheel']:
836 panels.append((plot_wheel_position, {'wheel_position': data['wheel'].position,
837 'wheel_time': data['wheel'].timestamps,
838 'trials_df': data['trials']}))
839 else:
840 panels.append((None, 'Data missing\nWheel position'))
842 # Panel F, G: Paw speed and nose speed
843 # Try if all data is there for left cam first, otherwise right
844 for cam in ['left', 'right']:
845 fail = False
846 if (data[f'{cam}_dlc'] is not None and data[f'{cam}_times'] is not None
847 and len(data[f'{cam}_times']) >= len(data[f'{cam}_dlc'])):
848 break
849 fail = True
850 if not fail:
851 paw = 'r' if cam == 'left' else 'l'
852 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
853 'trials_df': data['trials'], 'feature': f'paw_{paw}', 'cam': cam}))
854 panels.append((plot_speed_hist, {'dlc_df': data[f'{cam}_dlc'], 'cam_times': data[f'{cam}_times'],
855 'trials_df': data['trials'], 'feature': 'nose_tip', 'legend': False,
856 'cam': cam}))
857 else:
858 panels.extend([(None, 'Data missing or corrupt\nSpeed histograms')] * 2)
860 # Panel H and I: Lick plots
861 if data['licks'] and data['licks'].times.shape[0] > 0:
862 panels.append((plot_lick_hist, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
863 panels.append((plot_lick_raster, {'lick_times': data['licks'].times, 'trials_df': data['trials']}))
864 else:
865 panels.extend([(None, 'Data missing\nLicks plots') for i in range(2)])
867 # Panel J: pupil plot
868 # Try if all data is there for left cam first, otherwise right
869 for cam in ['left', 'right']:
870 fail = False
871 if (data.get(f'{cam}_times') is not None and data.get(f'{cam}_features') is not None
872 and len(data[f'{cam}_times']) >= len(data[f'{cam}_features'])
873 and not np.all(np.isnan(data[f'{cam}_features'].pupilDiameter_smooth))):
874 break
875 fail = True
876 if not fail:
877 panels.append((plot_pupil_diameter_hist,
878 {'pupil_diameter': data[f'{cam}_features'].pupilDiameter_smooth,
879 'cam_times': data[f'{cam}_times'], 'trials_df': data['trials'], 'cam': cam}))
880 else:
881 panels.append((None, 'Data missing or corrupt\nPupil diameter'))
883 # Plotting
884 plt.rcParams.update({'font.size': 10})
885 fig = plt.figure(figsize=(17, 10))
886 for i, panel in enumerate(panels):
887 ax = plt.subplot(2, 5, i + 1)
888 ax.text(-0.1, 1.15, ascii_uppercase[i], transform=ax.transAxes, fontsize=16, fontweight='bold')
889 # Check if there was in issue with inputs, if yes, print the respective text
890 if panel[0] is None:
891 ax.text(.5, .5, panel[1], color='r', fontweight='bold', fontsize=12, horizontalalignment='center',
892 verticalalignment='center', transform=ax.transAxes)
893 plt.axis('off')
894 else:
895 try:
896 panel[0](**panel[1])
897 except Exception:
898 logger.error(f'Error in {panel[0].__name__}\n' + traceback.format_exc())
899 ax.text(.5, .5, f'Error while plotting\n{panel[0].__name__}', color='r', fontweight='bold',
900 fontsize=12, horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
901 plt.axis('off')
902 plt.tight_layout(rect=[0, 0.03, 1, 0.95])
904 return fig