Coverage for ibllib/qc/task_qc_viewer/ViewEphysQC.py: 25%

129 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1"""An interactive PyQT QC data frame.""" 

2import logging 

3 

4from PyQt5 import QtCore, QtWidgets 

5from matplotlib.figure import Figure 

6from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT 

7import pandas as pd 

8import numpy as np 

9 

10from ibllib.misc import qt 

11 

12_logger = logging.getLogger(__name__) 

13 

14 

class DataFrameModel(QtCore.QAbstractTableModel):
    """Qt table model exposing a pandas DataFrame to item views.

    View rows and columns map *positionally* onto the DataFrame: view cell
    (i, j) is ``df.iat[i, j]`` whatever the DataFrame's index labels are.
    """

    # Custom item roles: the column dtype of a cell and the cell's raw value.
    DtypeRole = QtCore.Qt.UserRole + 1000
    ValueRole = QtCore.Qt.UserRole + 1001

    def __init__(self, df=None, parent=None):
        """
        :param df: the DataFrame to display (copied); an empty DataFrame when None
        :param parent: optional parent QObject
        """
        super(DataFrameModel, self).__init__(parent)
        # Build a fresh frame per instance (a `pd.DataFrame()` default would be
        # shared by all models) and copy the input like `setDataFrame` does, so
        # that `sort` (in place, index dropped) never alters the caller's frame.
        self._dataframe = pd.DataFrame() if df is None else df.copy()

    def setDataFrame(self, dataframe):
        """Replace the displayed DataFrame with a copy of `dataframe`."""
        self.beginResetModel()
        self._dataframe = dataframe.copy()
        self.endResetModel()

    def dataFrame(self):
        """Return the DataFrame currently held by the model."""
        return self._dataframe

    dataFrame = QtCore.pyqtProperty(pd.DataFrame, fget=dataFrame, fset=setDataFrame)

    @QtCore.pyqtSlot(int, QtCore.Qt.Orientation, result=str)
    def headerData(self, section: int, orientation: QtCore.Qt.Orientation,
                   role: int = QtCore.Qt.DisplayRole):
        """Return column names (horizontal) or index labels (vertical) as strings."""
        if role == QtCore.Qt.DisplayRole:
            if orientation == QtCore.Qt.Horizontal:
                # Column names need not be strings (e.g. integer columns).
                return str(self._dataframe.columns[section])
            return str(self._dataframe.index[section])
        return QtCore.QVariant()

    def rowCount(self, parent=QtCore.QModelIndex()):
        """Number of DataFrame rows; 0 for child indexes (the model is flat)."""
        if parent.isValid():
            return 0
        return len(self._dataframe.index)

    def columnCount(self, parent=QtCore.QModelIndex()):
        """Number of DataFrame columns; 0 for child indexes (the model is flat)."""
        if parent.isValid():
            return 0
        return self._dataframe.columns.size

    def data(self, index, role=QtCore.Qt.DisplayRole):
        """Return the cell at `index` for the requested role.

        DisplayRole gives ``str(value)``, ValueRole the raw value and DtypeRole
        the column dtype; any other role or an invalid index gives an empty QVariant.
        """
        if not index.isValid():
            return QtCore.QVariant()
        i, j = index.row(), index.column()
        if not (0 <= i < self.rowCount() and 0 <= j < self.columnCount()):
            return QtCore.QVariant()
        # Positional access: view row i is the i-th row, not the row labelled i.
        if role == QtCore.Qt.DisplayRole:
            return str(self._dataframe.iat[i, j])
        if role == DataFrameModel.ValueRole:
            return self._dataframe.iat[i, j]
        if role == DataFrameModel.DtypeRole:
            # Positional lookup also works with duplicated column names.
            return self._dataframe.dtypes.iloc[j]
        return QtCore.QVariant()

    def roleNames(self):
        """Map item roles to the byte-string names exposed to QML."""
        roles = {
            QtCore.Qt.DisplayRole: b'display',
            DataFrameModel.DtypeRole: b'dtype',
            DataFrameModel.ValueRole: b'value'
        }
        return roles

    def sort(self, col, order):
        """
        Sort table by given column number.

        :param col: the column number selected (between 0 and
            self._dataframe.columns.size - 1); out-of-range values such as -1
            are ignored
        :param order: a Qt.SortOrder: Qt.AscendingOrder (0) sorts ascending,
            Qt.DescendingOrder (1) sorts descending
        :return: None
        """
        if not 0 <= col < self.columnCount():
            return
        self.layoutAboutToBeChanged.emit()
        col_name = self._dataframe.columns.values[col]
        self._dataframe.sort_values(by=col_name, ascending=not order, inplace=True)
        # Renumber rows 0..n-1 so index labels line up with view rows again.
        self._dataframe.reset_index(inplace=True, drop=True)
        self.layoutChanged.emit()

92 

93 

class PlotCanvas(FigureCanvasQTAgg):
    """Matplotlib figure canvas embedded in a Qt widget.

    With `wheel` truthy, the figure holds a main axis ``self.ax`` stacked above
    a secondary axis ``self.ax2`` (shared x, 2:1 height ratio); otherwise it
    holds the single axis ``self.ax``.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100, wheel=None):
        """
        :param parent: optional parent widget
        :param width: figure width in inches
        :param height: figure height in inches
        :param dpi: figure resolution
        :param wheel: when truthy, also create the secondary axis ``self.ax2``
        """
        figure = Figure(figsize=(width, height), dpi=dpi)
        super(PlotCanvas, self).__init__(figure)
        self.setParent(parent)

        # Let the canvas grow in both directions inside its layout.
        expanding = QtWidgets.QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.updateGeometry()

        if not wheel:
            self.ax = figure.add_subplot(111)
        else:
            self.ax, self.ax2 = figure.subplots(
                2, 1, sharex=True, gridspec_kw={'height_ratios': [2, 1]})
        self.draw()

113 

114 

class PlotWindow(QtWidgets.QWidget):
    """Widget stacking a PlotCanvas above its matplotlib navigation toolbar."""

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: optional parent widget; None makes this a top-level window
        :param wheel: forwarded to PlotCanvas (truthy adds the secondary axis)
        """
        # Honour `parent`; it used to be discarded by passing parent=None.
        QtWidgets.QWidget.__init__(self, parent=parent)
        self.canvas = PlotCanvas(wheel=wheel)
        self.vbl = QtWidgets.QVBoxLayout()  # Set box for plotting
        self.vbl.addWidget(self.canvas)
        self.setLayout(self.vbl)
        self.vbl.addWidget(NavigationToolbar2QT(self.canvas, self))

123 

124 

class GraphWindow(QtWidgets.QWidget):
    """Sortable table of a QC DataFrame linked to a separate PlotWindow.

    Double-clicking a row zooms the plot onto that row's
    ['intervals_0', 'intervals_1'] interval, padded by 10% of its duration.
    """

    def __init__(self, parent=None, wheel=None):
        """
        :param parent: optional parent widget
        :param wheel: optional wheel data; when truthy it must be indexable by
            're_ts' (sorted times, as searched with np.searchsorted) and 're_pos'
            (an array sliced in parallel with 're_ts' — confirm against caller)
        """
        QtWidgets.QWidget.__init__(self, parent=parent)
        vLayout = QtWidgets.QVBoxLayout(self)
        hLayout = QtWidgets.QHBoxLayout()
        self.pathLE = QtWidgets.QLineEdit(self)
        hLayout.addWidget(self.pathLE)
        self.loadBtn = QtWidgets.QPushButton("Select File", self)
        hLayout.addWidget(self.loadBtn)
        vLayout.addLayout(hLayout)
        self.pandasTv = QtWidgets.QTableView(self)
        vLayout.addWidget(self.pandasTv)
        self.loadBtn.clicked.connect(self.load_file)
        self.pandasTv.setSortingEnabled(True)
        self.pandasTv.doubleClicked.connect(self.tv_double_clicked)
        self.wplot = PlotWindow(wheel=wheel)
        self.wplot.show()
        self.wheel = wheel

    def load_file(self):
        """Ask for a CSV file and display it; do nothing if the dialog is cancelled."""
        fileName, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Open File", "", "CSV Files (*.csv)")
        if not fileName:
            # Cancelling the dialog yields an empty path; reading it would raise.
            return
        self.pathLE.setText(fileName)
        df = pd.read_csv(fileName)
        self.update_df(df)

    def update_df(self, df):
        """Show `df` in the table view and redraw the plot canvas."""
        model = DataFrameModel(df)
        self.pandasTv.setModel(model)
        self.wplot.canvas.draw()

    def tv_double_clicked(self):
        """Zoom the plot onto the interval of the current table row.

        The x-range is [start - dt/10, finish + dt/10] with dt = finish - start.
        With wheel data, the secondary axis y-range is fitted to the 're_pos'
        samples falling inside that window (a warning is logged if there are none).
        """
        df = self.pandasTv.model()._dataframe
        ind = self.pandasTv.currentIndex()
        # View rows are positional, so use iloc rather than label-based loc.
        trial = df.iloc[ind.row()]
        start = trial['intervals_0']
        finish = trial['intervals_1']
        pad = (finish - start) / 10
        xlim = (start - pad, finish + pad)
        if self.wheel:
            idx = np.searchsorted(self.wheel['re_ts'], np.array(xlim))
            period = self.wheel['re_pos'][idx[0]:idx[1]]
            if period.size == 0:
                _logger.warning('No wheel data during trial #%i', ind.row())
            else:
                min_val, max_val = np.min(period), np.max(period)
                self.wplot.canvas.ax2.set_ylim(min_val - 1, max_val + 1)
            self.wplot.canvas.ax2.set_xlim(*xlim)
        self.wplot.canvas.ax.set_xlim(*xlim)

        self.wplot.canvas.draw()

175 

176 

def viewqc(qc=None, title=None, wheel=None):
    """Open the interactive QC viewer window.

    :param qc: optional DataFrame to load into the table straight away
    :param title: window title
    :param wheel: optional wheel data forwarded to GraphWindow
    :return: the GraphWindow, already shown
    """
    # Presumably ensures a Qt application exists before any widget is built
    # (see ibllib.misc.qt.create_app).
    qt.create_app()
    window = GraphWindow(wheel=wheel)
    if qc is not None:
        window.update_df(qc)
    window.setWindowTitle(title)
    window.show()
    return window