Coverage for ibllib/qc/task_qc_viewer/ViewEphysQC.py: 16%

183 statements  

« prev     ^ index     » next       coverage.py v7.7.0, created at 2025-03-17 09:55 +0000

1"""An interactive PyQT QC data frame.""" 

2 

3import logging 

4 

5from PyQt5 import QtWidgets 

6from PyQt5.QtCore import ( 

7 Qt, 

8 QModelIndex, 

9 pyqtSignal, 

10 pyqtSlot, 

11 QCoreApplication, 

12 QSettings, 

13 QSize, 

14 QPoint, 

15) 

16from PyQt5.QtGui import QPalette, QShowEvent 

17from PyQt5.QtWidgets import QMenu, QAction 

18from iblqt.core import ColoredDataFrameTableModel 

19from matplotlib.figure import Figure 

20from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT 

21import pandas as pd 

22import numpy as np 

23 

24from ibllib.misc import qt 

25 

# Module-level logger named after this module; GraphWindow.tv_double_clicked uses it
# to warn when no wheel samples fall inside the selected trial window.
_logger = logging.getLogger(__name__)

27 

28 

class PlotCanvas(FigureCanvasQTAgg):
    """Matplotlib figure canvas that can be embedded in a Qt layout.

    The canvas expands in both directions. When ``wheel`` is truthy the figure holds two
    vertically stacked axes sharing the x-axis: ``ax`` on top, twice as tall as ``ax2``
    below it. Otherwise it holds a single axis, ``ax``, and no ``ax2`` attribute exists.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None):
        figure = Figure(figsize=(width, height), dpi=dpi)

        FigureCanvasQTAgg.__init__(self, figure)
        self.setParent(parent)

        # let the canvas grow with its container
        expanding = QtWidgets.QSizePolicy.Expanding
        FigureCanvasQTAgg.setSizePolicy(self, expanding, expanding)
        FigureCanvasQTAgg.updateGeometry(self)

        if not wheel:
            self.ax = figure.add_subplot(111)
        else:
            self.ax, self.ax2 = figure.subplots(
                2, 1, sharex=True, gridspec_kw={'height_ratios': [2, 1]}
            )
        self.draw()

43 

44 

class PlotWindow(QtWidgets.QWidget):
    """Widget holding a :class:`PlotCanvas` above its matplotlib navigation toolbar."""

    def __init__(self, parent=None, wheel=None):
        """Build the canvas and toolbar.

        Parameters
        ----------
        parent : QWidget, optional
            Parent widget. With the default ``None`` this widget is a top-level window.
        wheel : optional
            Forwarded to :class:`PlotCanvas`. When truthy, the canvas gets a second axis
            (``canvas.ax2``).
        """
        # forward the caller's parent instead of silently discarding it
        QtWidgets.QWidget.__init__(self, parent=parent)
        self.canvas = PlotCanvas(wheel=wheel)
        self.vbl = QtWidgets.QVBoxLayout()  # vertical box: canvas on top, toolbar below
        self.vbl.addWidget(self.canvas)
        self.setLayout(self.vbl)
        self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self))

53 

54 

class GraphWindow(QtWidgets.QWidget):
    """Interactive, colour-coded table view of a QC data frame.

    The table is backed by a ``ColoredDataFrameTableModel``. The controls above it let
    the user:

    - filter columns by name;
    - choose a colormap;
    - set the colour alpha;
    - load a CSV file.

    Columns can be pinned from the header context menu so that they ignore the filter.
    A separate :class:`PlotWindow` is opened alongside. Double-clicking a row zooms that
    plot to the row's ``intervals_0``/``intervals_1`` window.
    """

    # PyQt only turns ``pyqtSignal`` objects into working (bound) signals when they are
    # declared as class attributes of a QObject subclass, never when assigned per instance.
    columnPinned = pyqtSignal(int, bool)

    def __init__(self, parent=None, wheel=None):
        """Build the widgets, restore persisted settings and open the plot window.

        Parameters
        ----------
        parent : QWidget, optional
            Parent widget.
        wheel : optional
            Forwarded to :class:`PlotWindow`. When truthy, ``wheel['re_ts']`` and
            ``wheel['re_pos']`` are read on double-click. These are presumably wheel
            timestamps and positions; confirm with callers.
        """
        QtWidgets.QWidget.__init__(self, parent=parent)

        # Per-instance list of pinned (always visible) column indices. A class-level
        # list would be shared, and mutated, by every GraphWindow instance.
        self._pinnedColumns = []
        self.wheel = wheel

        # load button
        self.pushButtonLoad = QtWidgets.QPushButton('Select File', self)
        self.pushButtonLoad.clicked.connect(self.loadFile)

        # define table model & view
        self.tableModel = ColoredDataFrameTableModel(self)
        self.tableView = QtWidgets.QTableView(self)
        self.tableView.setModel(self.tableModel)
        self.tableView.setSortingEnabled(True)
        header = self.tableView.horizontalHeader()
        header.setDefaultAlignment(Qt.AlignLeft | Qt.AlignVCenter)
        header.setSectionsMovable(True)
        header.setContextMenuPolicy(Qt.CustomContextMenu)
        header.customContextMenuRequested.connect(self.contextMenu)
        self.tableView.verticalHeader().hide()
        self.tableView.doubleClicked.connect(self.tv_double_clicked)

        # define colors for highlighted cells
        p = self.tableView.palette()
        p.setColor(QPalette.Highlight, Qt.black)
        p.setColor(QPalette.HighlightedText, Qt.white)
        self.tableView.setPalette(p)

        # QAction for pinning columns (shared by every header context menu)
        self.pinAction = QAction('Pin column', self)
        self.pinAction.setCheckable(True)
        self.pinAction.toggled.connect(self.pinColumn)

        # Filter columns by name
        self.lineEditFilter = QtWidgets.QLineEdit(self)
        self.lineEditFilter.setPlaceholderText('Filter columns')
        self.lineEditFilter.textChanged.connect(self.changeFilter)
        self.lineEditFilter.setMinimumWidth(200)

        # colormap picker
        self.comboboxColormap = QtWidgets.QComboBox(self)
        colormaps = {self.tableModel.colormap, 'inferno', 'magma', 'plasma', 'summer'}
        self.comboboxColormap.addItems(sorted(colormaps))
        self.comboboxColormap.setCurrentText(self.tableModel.colormap)
        self.comboboxColormap.currentTextChanged.connect(self.tableModel.setColormap)

        # slider for alpha values
        self.sliderAlpha = QtWidgets.QSlider(Qt.Horizontal, self)
        self.sliderAlpha.setMaximumWidth(100)
        self.sliderAlpha.setMinimum(0)
        self.sliderAlpha.setMaximum(255)
        self.sliderAlpha.setValue(self.tableModel.alpha)
        self.sliderAlpha.valueChanged.connect(self.tableModel.setAlpha)

        # Horizontal layout
        hLayout = QtWidgets.QHBoxLayout()
        hLayout.addWidget(self.lineEditFilter)
        hLayout.addSpacing(50)
        hLayout.addWidget(QtWidgets.QLabel('Colormap', self))
        hLayout.addWidget(self.comboboxColormap)
        hLayout.addWidget(QtWidgets.QLabel('Alpha', self))
        hLayout.addWidget(self.sliderAlpha)
        hLayout.addSpacing(50)
        hLayout.addWidget(self.pushButtonLoad)

        # Vertical layout
        vLayout = QtWidgets.QVBoxLayout(self)
        vLayout.addLayout(hLayout)
        vLayout.addWidget(self.tableView)

        # Recover layout from QSettings. The widget signals are already connected at this
        # point, so restoring them also updates the table model.
        self.settings = QSettings()
        self.settings.beginGroup('MainWindow')
        self.resize(self.settings.value('size', QSize(800, 600), QSize))
        self.comboboxColormap.setCurrentText(self.settings.value('colormap', 'plasma', str))
        self.sliderAlpha.setValue(self.settings.value('alpha', 255, int))
        self.settings.endGroup()

        self.wplot = PlotWindow(wheel=wheel)
        self.wplot.show()
        self.tableModel.dataChanged.connect(self.wplot.canvas.draw)

    def closeEvent(self, _) -> None:
        """Persist size, colormap and alpha to QSettings, then close the plot window."""
        self.settings.beginGroup('MainWindow')
        self.settings.setValue('size', self.size())
        self.settings.setValue('colormap', self.tableModel.colormap)
        self.settings.setValue('alpha', self.tableModel.alpha)
        self.settings.endGroup()
        self.wplot.close()

    def showEvent(self, a0: QShowEvent) -> None:
        """Bring the window to the foreground when it is shown."""
        super().showEvent(a0)
        self.activateWindow()

    def contextMenu(self, pos: QPoint):
        """Offer to pin or unpin the header column under *pos*.

        Parameters
        ----------
        pos : QPoint
            Position in the coordinates of the header that emitted the request.
        """
        header = self.sender()
        idx = header.logicalIndexAt(pos)
        action = self.pinAction
        # set the target column before the check state: setChecked may emit toggled,
        # which calls pinColumn and reads the column index from the action's data
        action.setData(idx)
        action.setChecked(idx in self._pinnedColumns)
        menu = QMenu(self)
        menu.addAction(action)
        menu.exec(header.mapToGlobal(pos))
        # the menu is parented to this window; release it so one QMenu child is not
        # accumulated per right-click (the shared action is owned by self, not the menu)
        menu.deleteLater()

    @pyqtSlot(bool)
    @pyqtSlot(bool, int)
    def pinColumn(self, pin: bool, idx: int | None = None):
        """Pin or unpin a column so that it stays visible regardless of the filter.

        Parameters
        ----------
        pin : bool
            True to pin the column, False to unpin it.
        idx : int, optional
            Column index. When omitted, it is read from the sending action's data,
            as set by ``contextMenu``.
        """
        if idx is None:
            idx = self.sender().data()
        if pin and idx not in self._pinnedColumns:
            self._pinnedColumns.append(idx)
        elif not pin and idx in self._pinnedColumns:
            self._pinnedColumns.remove(idx)
        self.changeFilter(self.lineEditFilter.text())

    def changeFilter(self, string: str):
        """Show only the columns whose header matches the filter text.

        Parameters
        ----------
        string : str
            Comma-separated tokens. A column is shown if its (lower-cased) header
            contains any token, if it is pinned, or if there are no non-empty tokens.
        """
        tokens = [token.strip().lower() for token in string.split(',')]
        tokens = [token for token in tokens if token]
        for idx in range(self.tableModel.columnCount()):
            # str() so non-string column labels (e.g. integers) do not break .lower()
            header = str(self.tableModel.headerData(idx, Qt.Horizontal, Qt.DisplayRole)).lower()
            show = not tokens or any(token in header for token in tokens) or idx in self._pinnedColumns
            self.tableView.setColumnHidden(idx, not show)

    def loadFile(self):
        """Ask the user for a CSV file and display its contents.

        Read errors are logged and reported in a message box. An exception escaping this
        slot would otherwise abort the Qt application.
        """
        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(self, 'Open File', '', 'CSV Files (*.csv)')
        if not fileName:
            return
        try:
            df = pd.read_csv(fileName)
        except (OSError, ValueError) as err:  # pandas parser/empty-data/decoding errors are ValueErrors
            _logger.error('Failed to read %s: %s', fileName, err)
            QtWidgets.QMessageBox.warning(self, 'Open File', f'Could not read {fileName}:\n{err}')
            return
        self.updateDataframe(df)

    def updateDataframe(self, df: pd.DataFrame):
        """Reorder the columns of *df* for display and load them into the table model.

        The columns are arranged as follows:

        - Timestamp-like columns go first, ordered by mean. These are numeric,
          non-constant columns that are monotonically increasing once infinities are
          removed and gaps interpolated.
        - Boolean columns go last. Among them, columns whose name contains 'pass' come at
          the very end, ordered by ascending count of False values.
        - ``trial_no``, if present, is moved to the front and pinned.

        The caller's DataFrame is not modified; a reordered copy is displayed.

        Parameters
        ----------
        df : pandas.DataFrame
            The QC table to display.
        """
        # work on a copy: the pop/insert calls below would otherwise mutate the caller's frame
        df = df.copy()

        # clear pinned columns
        self._pinnedColumns = []

        # try to identify and sort columns containing timestamps
        col_names = df.select_dtypes('number').columns
        df_interp = df[col_names].replace([-np.inf, np.inf], np.nan)
        df_interp = df_interp.interpolate(limit_direction='both')
        cols_mono = [c for c in col_names if df_interp[c].is_monotonic_increasing and df[c].nunique() > 1]
        cols_mono = df_interp[cols_mono].mean().sort_values().keys()
        for idx, col_name in enumerate(cols_mono):
            df.insert(idx, col_name, df.pop(col_name))

        # boolean columns are moved to the end; 'pass' columns go last, sorted by number of False values
        cols_bool = list(df.select_dtypes(['bool', 'boolean']).columns)
        cols_pass = [c for c in cols_bool if 'pass' in str(c)]
        cols_other = [c for c in cols_bool if c not in cols_pass]
        cols_pass = list((~df[cols_pass]).sum().sort_values().keys())
        for col_name in cols_other + cols_pass:
            df = df.join(df.pop(col_name))

        # trial_no is always the first column and pinned by default - but only if it exists
        if 'trial_no' in df.columns:
            df.insert(0, 'trial_no', df.pop('trial_no'))
            self._pinnedColumns.append(df.columns.get_loc('trial_no'))

        self.tableModel.setDataFrame(df)
        # re-apply the current filter so hidden columns match the new column layout
        self.changeFilter(self.lineEditFilter.text())

    def tv_double_clicked(self, index: QModelIndex):
        """Zoom the plot window to the trial interval of the double-clicked row.

        The x-limits are set to ``[intervals_0, intervals_1]`` padded by 10% of the
        duration on each side. When wheel data is present, the lower axis y-limits are
        fitted to the wheel positions inside that window.

        Parameters
        ----------
        index : QModelIndex
            Index of the clicked cell; only its row is used.
        """
        data = self.tableModel.dataFrame.iloc[index.row()]
        if 'intervals_0' not in data or 'intervals_1' not in data:
            _logger.warning('Table has no intervals_0/intervals_1 columns; cannot zoom to trial')
            return
        t0 = data['intervals_0']
        t1 = data['intervals_1']
        if not np.isfinite([t0, t1]).all():
            # matplotlib raises on NaN/inf axis limits
            _logger.warning('Invalid trial interval in row %i', index.row())
            return
        dt = t1 - t0
        xlim = (t0 - dt / 10, t1 + dt / 10)
        if self.wheel:
            idx = np.searchsorted(self.wheel['re_ts'], np.array(xlim))
            period = self.wheel['re_pos'][idx[0]:idx[1]]
            if period.size == 0:
                _logger.warning('No wheel data during trial #%i', index.row())
            else:
                min_val, max_val = np.min(period), np.max(period)
                self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
            self.wplot.canvas.ax2.set_xlim(*xlim)
        self.wplot.canvas.ax.set_xlim(*xlim)
        self.wplot.setWindowTitle(f"Trial {data.get('trial_no', '?')}")
        self.wplot.canvas.draw()

242 

243 

def viewqc(qc=None, title=None, wheel=None):
    """Create, populate and show a QC viewer window.

    The Qt application is obtained from ``qt.create_app()`` and uses the 'Fusion' style.
    The organisation and application names are set so that QSettings persist under
    'International Brain Laboratory' / 'QC Viewer'. This function does not start an
    event loop itself.

    Parameters
    ----------
    qc : pandas.DataFrame, optional
        QC table to display. If None, the window starts empty.
    title : str, optional
        Window title.
    wheel : optional
        Forwarded to GraphWindow for the wheel subplot.

    Returns
    -------
    GraphWindow
        The shown viewer window.
    """
    application = qt.create_app()
    application.setStyle('Fusion')
    for setter, value in (
        (QCoreApplication.setOrganizationName, 'International Brain Laboratory'),
        (QCoreApplication.setOrganizationDomain, 'internationalbrainlab.org'),
        (QCoreApplication.setApplicationName, 'QC Viewer'),
    ):
        setter(value)

    window = GraphWindow(wheel=wheel)
    window.setWindowTitle(title)
    if qc is not None:
        window.updateDataframe(qc)
    window.show()
    return window