Coverage for brainbox/plot_base.py: 67%

284 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1import numpy as np 

2import matplotlib.pyplot as plt 

3from matplotlib.image import NonUniformImage 

4import matplotlib 

5from matplotlib import cm 

6from iblutil.util import Bunch 

7 

8 

# Maps an axis name to the dimension of the image data it indexes (used as data['c'].shape[...])
axis_dict = dict(x=0, y=1, z=2)

10 

11 

class DefaultPlot(object):

    def __init__(self, plot_type, data):
        """
        Base class for organising data into a structure that can be easily used to create plots.
        The idea is that the dictionary is independent of plotting method and so can be fed into
        matplotlib, pyqtgraph, datoviz (or any plotting method of choice).

        :param plot_type: type of plot (just for reference)
        :type plot_type: string
        :param data: dict of data containing at least 'x', 'y', and may additionally contain 'z'
         for 3D plots and 'c' for 2D (image or scatter) plots with a third variable represented
         by colour
        :type data: dict
        """
        self.plot_type = plot_type
        self.data = data
        # Reference lines registered through add_lines()
        self.hlines = []
        self.vlines = []
        # Start with every label unset (None)
        self.set_labels()

    def add_lines(self, pos, orientation, lim=None, style='--', width=3, color='k'):
        """
        Specify the position and style of a horizontal or vertical reference line

        :param pos: position of line
        :param orientation: either 'v' for vertical line or 'h' for horizontal line
        :param lim: extent of line (start, end). Defaults to self.ylim for vertical lines and
         self.xlim for horizontal lines; that attribute is read even when lim is given, so it
         must already be set
        :param style: line style
        :param width: line width
        :param color: line colour
        :raises ValueError: if orientation is neither 'v' nor 'h'
        """
        if orientation == 'v':
            lim = self._set_default(lim, self.ylim)
            lines = self.vlines
        elif orientation == 'h':
            lim = self._set_default(lim, self.xlim)
            lines = self.hlines
        else:
            # An unknown orientation used to be silently ignored, hiding caller mistakes
            raise ValueError(f"orientation must be 'v' or 'h', got {orientation!r}")
        lines.append(Bunch({'pos': pos, 'lim': lim, 'style': style, 'width': width,
                            'color': color}))

    def set_labels(self, title=None, xlabel=None, ylabel=None, zlabel=None, clabel=None):
        """
        Set labels for plot. Any label not given is reset to None.

        :param title: title
        :param xlabel: x axis label
        :param ylabel: y axis label
        :param zlabel: z axis label
        :param clabel: cbar label
        """
        self.labels = Bunch({'title': title, 'xlabel': xlabel, 'ylabel': ylabel, 'zlabel': zlabel,
                             'clabel': clabel})

    def set_xlim(self, xlim=None):
        """
        Set xlim values

        :param xlim: xlim values (min, max) supports tuple, list or np.array of len(2). If not
         specified will compute as nanmin, nanmax of x data
        """
        self.xlim = self._set_lim('x', lim=xlim)

    def set_ylim(self, ylim=None):
        """
        Set ylim values

        :param ylim: ylim values (min, max) supports tuple, list or np.array of len(2). If not
         specified will compute as nanmin, nanmax of y data
        """
        self.ylim = self._set_lim('y', lim=ylim)

    def set_zlim(self, zlim=None):
        """
        Set zlim values

        :param zlim: zlim values (min, max) supports tuple, list or np.array of len(2). If not
         specified will compute as nanmin, nanmax of z data
        """
        self.zlim = self._set_lim('z', lim=zlim)

    def set_clim(self, clim=None):
        """
        Set clim values

        :param clim: clim values (min, max) supports tuple, list or np.array of len(2). If not
         specified will compute as nanmin, nanmax of c data
        """
        self.clim = self._set_lim('c', lim=clim)

    def _set_lim(self, axis, lim=None):
        """
        General function to set limits to either the specified value if lim is not None or to
        (nanmin, nanmax) of the data along that axis

        :param axis: x, y, z or c
        :param lim: lim values (min, max) supports tuple, list or np.array of len(2)
        :return: lim as given, or a (min, max) tuple computed from self.data[axis]
        """
        if lim is not None:
            assert len(lim) == 2
        else:
            lim = (np.nanmin(self.data[axis]), np.nanmax(self.data[axis]))
        return lim

    def _set_default(self, val, default):
        """
        General function to set value of attribute. If val is not None, the value of val will be
        returned otherwise default value will be returned

        :param val: non-default value to set attribute to
        :param default: default value of attribute
        :return: val if it is not None, else default
        """
        if val is None:
            return default
        else:
            return val

    def convert2dict(self):
        """
        Convert class object to dictionary

        :return: dict with variables needed for plotting. This is a shallow copy of the instance
         attributes: replacing a key in the returned dict does not modify the plot object, but
         nested containers (e.g. the data dict) are shared
        """
        return dict(vars(self))

138 

139 

class ImagePlot(DefaultPlot):
    def __init__(self, img, x=None, y=None, cmap=None):
        """
        Class for organising data that will be used to create 2D image plots

        :param img: 2D image data (array-like), indexed as img[x_index, y_index]
        :param x: x coordinate of each image voxel in x dimension (defaults to
         np.arange(img.shape[0]))
        :param y: y coordinate of each image voxel in y dimension (defaults to
         np.arange(img.shape[1]))
        :param cmap: name of colormap to use (defaults to 'viridis')
        """
        # Accept any array-like input; asanyarray leaves ndarrays (and subclasses) untouched
        img = np.asanyarray(img)
        x = np.asanyarray(self._set_default(x, np.arange(img.shape[0])))
        y = np.asanyarray(self._set_default(y, np.arange(img.shape[1])))
        data = Bunch({'x': x, 'y': y, 'c': img})

        # Make sure dimensions agree
        assert data['c'].shape[0] == data['x'].shape[0], 'dimensions must agree'
        assert data['c'].shape[1] == data['y'].shape[0], 'dimensions must agree'

        # Initialise default plot class with data
        super().__init__('image', data)
        # Populated on demand by set_scale() / set_offset()
        self.scale = None
        self.offset = None
        self.cmap = self._set_default(cmap, 'viridis')

        self.set_xlim()
        self.set_ylim()
        self.set_clim()

    def set_scale(self, scale=None):
        """
        Set the scaling factor to apply to image (mainly for pyqtgraph implementation)

        :param scale: scale values (xscale, yscale), supports tuple, list or np.array of len(2).
         If not specified will automatically compute from xlims/ylims and shape of data
        """
        if scale is None:
            # Only derive the default when it is actually needed
            scale = (self._get_scale('x'), self._get_scale('y'))
        else:
            assert len(scale) == 2
        self.scale = scale

    def _get_scale(self, axis):
        """
        Calculate scaling factor to apply along axis. Don't use directly, use set_scale() method

        :param axis: 'x' or 'y'
        :return: extent of the axis limits divided by the number of image samples along axis
        """
        lim = self.xlim if axis == 'x' else self.ylim
        # Validates that the stored limits have length 2
        lim = self._set_lim(axis, lim=lim)
        return (lim[1] - lim[0]) / self.data['c'].shape[axis_dict[axis]]

    def set_offset(self, offset=None):
        """
        Set the offset to apply to the image (mainly for pyqtgraph implementation)

        :param offset: offset values (xoffset, yoffset), supports tuple, list or np.array of len(2)
         If not specified will automatically compute from the minimum (ignoring NaNs) of the x and
         y data
        """
        if offset is None:
            # Only derive the default when it is actually needed
            offset = (self._get_offset('x'), self._get_offset('y'))
        else:
            assert len(offset) == 2
        self.offset = offset

    def _get_offset(self, axis):
        """
        Calculate offset to apply to axis. Don't use directly, use set_offset() method

        :param axis: 'x' or 'y'
        :return: nanmin of the data along axis
        """
        return np.nanmin(self.data[axis])

217 

218 

class ProbePlot(DefaultPlot):
    def __init__(self, img, x, y, cmap=None):
        """
        Class for organising data that will be used to create 2D probe plots. Use function
        plot_base.arrange_channels2banks to prepare data in correct format before using this class

        :param img: list of image data for each bank of probe
        :param x: list of x coordinate for each bank of probe
        :param y: list of y coordinate for each bank of probe
        :param cmap: name of cmap (defaults to 'viridis')
        """

        # Make sure we have inputs as lists, can get input from arrange_channels2banks
        assert isinstance(img, list)
        assert isinstance(x, list)
        assert isinstance(y, list)

        data = Bunch({'x': x, 'y': y, 'c': img})
        super().__init__('probe', data)
        self.cmap = self._set_default(cmap, 'viridis')

        self.set_xlim()
        self.set_ylim()
        self.set_clim()
        self.set_scale()
        self.set_offset()

    def set_scale(self, idx=None, scale=None):
        """
        Set the (xscale, yscale) scaling factor of the banks

        :param idx: index of the bank whose scale is replaced (only used when scale is given)
        :param scale: (xscale, yscale) for bank idx. If None, the scale of every bank is recomputed
         from the extent of its x/y coordinates and the shape of its image data
        """
        if scale is not None:
            self.scale[idx] = scale
        else:
            self.scale = [(self._get_scale(i, 'x'), self._get_scale(i, 'y'))
                          for i in range(len(self.data['x']))]

    def _get_scale(self, idx, axis):
        """
        Scaling factor of bank idx along axis: coordinate extent / number of image samples

        :param idx: bank index
        :param axis: 'x' or 'y'
        """
        lim = self._set_lim_list(axis, idx)
        return (lim[1] - lim[0]) / self.data['c'][idx].shape[axis_dict[axis]]

    def _set_lim_list(self, axis, idx, lim=None):
        """
        Limits of a single bank: lim if given, otherwise (nanmin, nanmax) of self.data[axis][idx]

        :param axis: 'x', 'y' or 'c'
        :param idx: bank index
        :param lim: optional (min, max) of len(2)
        """
        if lim is not None:
            assert len(lim) == 2
        else:
            lim = (np.nanmin(self.data[axis][idx]), np.nanmax(self.data[axis][idx]))
        return lim

    def set_offset(self, idx=None, offset=None):
        """
        Set the (xoffset, yoffset) of the banks

        :param idx: index of the bank whose offset is replaced (only used when offset is given)
        :param offset: (xoffset, yoffset) for bank idx. If None, the offset of every bank is
         recomputed as the minimum (ignoring NaNs) of its x and y coordinates
        """
        if offset is not None:
            self.offset[idx] = offset
        else:
            # nanmin, consistent with ImagePlot, so a NaN coordinate does not poison the offset
            self.offset = [(np.nanmin(self.data['x'][i]), np.nanmin(self.data['y'][i]))
                           for i in range(len(self.data['x']))]

    def _set_lim(self, axis, lim=None):
        """
        Limits across all banks: lim if given, otherwise (nanmin, nanmax) over every bank's data

        :param axis: 'x', 'y' or 'c'
        :param lim: optional (min, max) of len(2)
        """
        if lim is not None:
            assert len(lim) == 2
        else:
            # np.ravel always yields 1D arrays; squeezing them produced 0-d arrays for
            # single-element banks, which np.concatenate rejects
            data = np.concatenate([np.ravel(d) for d in self.data[axis]])
            lim = (np.nanmin(data), np.nanmax(data))
        return lim

279 

280 

class ScatterPlot(DefaultPlot):
    def __init__(self, x, y, z=None, c=None, cmap=None, plot_type='scatter'):
        """
        Class for organising data that will be used to create scatter plots. Can be 2D or 3D (if
        z given). Can also represent variable through color by specifying c

        :param x: x values for data
        :param y: y values for data
        :param z: z values for data
        :param c: values to use to represent color of scatter points
        :param cmap: name of colormap to use if c is given (defaults to 'viridis')
        :param plot_type: label stored on the object describing the kind of plot
        """
        data = Bunch({'x': x, 'y': y, 'z': z, 'c': c})

        assert len(data['x']) == len(data['y']), 'dimensions must agree'
        if data['z'] is not None:
            assert len(data['z']) == len(data['x']), 'dimensions must agree'
        if data['c'] is not None:
            assert len(data['c']) == len(data['x']), 'dimensions must agree'

        super().__init__(plot_type, data)

        self._set_init_style()
        self.set_xlim()
        self.set_ylim()
        # If we have 3D data
        if data['z'] is not None:
            self.set_zlim()
        # Colour limits only exist when there is a colour variable; without one, state it
        # explicitly rather than relying on np.nanmin(None) happening to return None
        if data['c'] is not None:
            self.set_clim()
        else:
            self.clim = (None, None)
        self.cmap = self._set_default(cmap, 'viridis')

    def _set_init_style(self):
        """
        Initialise default style: colour 'b', marker 'o', opacity 1, edge line style '-';
        marker size, edge colour and edge width are left as None
        """
        self.set_color()
        self.set_marker_size()
        self.set_marker_type('o')
        self.set_opacity()
        self.set_line_color()
        self.set_line_width()
        self.set_line_style()

    def set_color(self, color=None):
        """
        Color of scatter points.

        :param color: string e.g 'k', single RGB e,g [0,0,0] or np.array of RGB. In the latter case
         must give same no. of colours as datapoints i.e. len(np.array(RGB)) == len(data['x']).
         Defaults to 'b'
        """
        self.color = self._set_default(color, 'b')

    def set_marker_size(self, marker_size=None):
        """
        Size of each scatter point

        :param marker_size: int or np.array of int. In the latter case must give same no. of
         marker_size as datapoints i.e len(np.array(marker_size)) == len(data['x'])
        """
        self.marker_size = self._set_default(marker_size, None)

    def set_marker_type(self, marker_type=None):
        """
        Shape of each scatter point

        :param marker_type: marker identifier (e.g. 'o'); None leaves it unset
        """
        self.marker_type = self._set_default(marker_type, None)

    def set_opacity(self, opacity=None):
        """
        Opacity of each scatter point

        :param opacity: opacity value, defaults to 1
        """
        self.opacity = self._set_default(opacity, 1)

    def set_line_color(self, line_color=None):
        """
        Colour of edge of scatter point

        :param line_color: string e.g 'k' or RGB e.g [0,0,0]
        """
        self.line_color = self._set_default(line_color, None)

    def set_line_width(self, line_width=None):
        """
        Width of line on edge of scatter point

        :param line_width: int
        """
        self.line_width = self._set_default(line_width, None)

    def set_line_style(self, line_style=None):
        """
        Style of line on edge of scatter point

        :param line_style: line style identifier, defaults to '-'
        """
        self.line_style = self._set_default(line_style, '-')

389 

390 

class LinePlot(ScatterPlot):
    def __init__(self, x, y):
        """
        Class for organising data that will be used to create line plots.

        :param x: x values for data
        :param y: y values for data
        """
        # ScatterPlot.__init__ already calls self._set_init_style() (dispatching to the override
        # below) and sets xlim/ylim, so repeating those calls here would be redundant
        super().__init__(x, y, plot_type='line')

    def _set_init_style(self):
        """
        Initialise line-plot defaults: line colour 'k', line width 2, default line style;
        marker size and marker type are left as None
        """
        self.set_line_color('k')
        self.set_line_width(2)
        self.set_line_style()
        self.set_marker_size()
        self.set_marker_type()

411 

412 

def add_lines(ax, data, **kwargs):
    """
    Draw the reference lines stored in a plot dict onto a matplotlib axis.

    All vertical lines (data['vlines']) are drawn first, then all horizontal lines
    (data['hlines']).

    :param ax: matplotlib axis
    :param data: dict of plot data with 'vlines' and 'hlines' lists; each entry provides
     'pos', 'lim' (start, end), 'style', 'width' and 'color'
    :param kwargs: matplotlib keywords arguments forwarded to every ax.vlines/ax.hlines call
    :return: the same axis
    """
    # (collection key, drawing method, keyword names for the start/end of the line)
    specs = (('vlines', ax.vlines, 'ymin', 'ymax'),
             ('hlines', ax.hlines, 'xmin', 'xmax'))
    for key, draw, start_kw, end_kw in specs:
        for line in data[key]:
            extent = {start_kw: line['lim'][0], end_kw: line['lim'][1]}
            draw(line['pos'], **extent, linestyles=line['style'], linewidth=line['width'],
                 colors=line['color'], **kwargs)
    return ax

434 

435 

def plot_image(data, ax=None, show_cbar=True, fig_kwargs=None, line_kwargs=None,
               img_kwargs=None):
    """
    Function to create matplotlib plot from ImagePlot object

    :param data: ImagePlot object, either class or dict
    :param ax: matplotlib axis to plot on, if None, will create figure
    :param show_cbar: whether or not to display colour bar
    :param fig_kwargs: dict of matplotlib keywords associated with plt.subplots e.g can be
     fig size, tight layout etc.
    :param line_kwargs: dict of matplotlib keywords associated with ax.hlines/ax.vlines
    :param img_kwargs: dict of matplotlib keywords associated with matplotlib.imshow
    :return: matplotlib axis and figure handles
    """
    # None defaults avoid sharing one mutable dict between calls
    fig_kwargs = fig_kwargs or dict()
    line_kwargs = line_kwargs or dict()
    img_kwargs = img_kwargs or dict()

    if not isinstance(data, dict):
        data = data.convert2dict()

    if ax is None:
        fig, ax = plt.subplots(**fig_kwargs)
    else:
        # Use the figure that owns ax; plt.gcf() may be a different (current) figure
        fig = ax.figure

    # c is indexed [x, y]; imshow expects [row=y, col=x], hence the transpose
    img = ax.imshow(data['data']['c'].T, extent=np.r_[data['xlim'], data['ylim']],
                    cmap=data['cmap'], vmin=data['clim'][0], vmax=data['clim'][1], origin='lower',
                    aspect='auto', **img_kwargs)

    ax.set_xlim(data['xlim'][0], data['xlim'][1])
    ax.set_ylim(data['ylim'][0], data['ylim'][1])
    ax.set_xlabel(data['labels']['xlabel'])
    ax.set_ylabel(data['labels']['ylabel'])
    ax.set_title(data['labels']['title'])

    if show_cbar:
        cbar = fig.colorbar(img, ax=ax)
        cbar.set_label(data['labels']['clabel'])

    ax = add_lines(ax, data, **line_kwargs)

    return ax, fig

475 

476 

def plot_scatter(data, ax=None, show_cbar=True, fig_kwargs=None, line_kwargs=None,
                 scat_kwargs=None):
    """
    Function to create matplotlib plot from ScatterPlot object. If data['color'] gives a colour
    for each data point it will override automatic colours that would be generated from
    data['data']['c']

    :param data: ScatterPlot object, either class or dict. It is not modified by this function
    :param ax: matplotlib axis to plot on, if None, will create figure
    :param show_cbar: whether or not to display colour bar
    :param fig_kwargs: dict of matplotlib keywords associated with plt.subplots e.g can be
     fig size, tight layout etc.
    :param line_kwargs: dict of matplotlib keywords associated with ax.hlines/ax.vlines
    :param scat_kwargs: dict of matplotlib keywords associated with matplotlib.scatter
    :return: matplotlib axis and figure handles
    """
    # None defaults avoid sharing one mutable dict between calls
    fig_kwargs = fig_kwargs or dict()
    line_kwargs = line_kwargs or dict()
    scat_kwargs = scat_kwargs or dict()

    if not isinstance(data, dict):
        data = data.convert2dict()

    if ax is None:
        fig, ax = plt.subplots(**fig_kwargs)
    else:
        # Use the figure that owns ax; plt.gcf() may be a different (current) figure
        fig = ax.figure

    x = data['data']['x']
    y = data['data']['y']
    c = data['data']['c']
    marker_kwargs = dict(s=data['marker_size'], marker=data['marker_type'],
                         edgecolors=data['line_color'], linewidths=data['line_width'])

    if c is None:
        # Single color for all points
        ax.scatter(x=x, y=y, c=data['color'], **marker_kwargs, **scat_kwargs)
    else:
        color = data['color']
        # A plain string (e.g. 'b') is a single colour, even when there is only one point
        per_point = (color is not None and not isinstance(color, str)
                     and len(color) == len(x))
        if per_point:
            # Colour for each point specified
            color = np.asarray(color)
            if np.issubdtype(color.dtype, np.number) and np.max(color) > 1:
                # Treat as 0-255 values; rebinding a local keeps data['color'] (and the plot
                # object behind it) unchanged, so repeated calls do not rescale again
                color = color / 255
            ax.scatter(x=x, y=y, c=color, **marker_kwargs, **scat_kwargs)
            if show_cbar:
                norm = matplotlib.colors.Normalize(vmin=data['clim'][0], vmax=data['clim'][1],
                                                   clip=True)
                cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=data['cmap']), ax=ax)
                cbar.set_label(data['labels']['clabel'])
        else:
            # Automatically generate from c data
            scat = ax.scatter(x=x, y=y, c=c, cmap=data['cmap'],
                              vmin=data['clim'][0], vmax=data['clim'][1],
                              **marker_kwargs, **scat_kwargs)
            if show_cbar:
                cbar = fig.colorbar(scat, ax=ax)
                cbar.set_label(data['labels']['clabel'])

    ax = add_lines(ax, data, **line_kwargs)

    ax.set_xlim(data['xlim'][0], data['xlim'][1])
    ax.set_ylim(data['ylim'][0], data['ylim'][1])
    ax.set_xlabel(data['labels']['xlabel'])
    ax.set_ylabel(data['labels']['ylabel'])
    ax.set_title(data['labels']['title'])

    return ax, fig

542 

543 

def plot_probe(data, ax=None, show_cbar=True, make_pretty=True, fig_kwargs=None,
               line_kwargs=None):
    """
    Function to create matplotlib plot from ProbePlot object

    :param data: ProbePlot object, either class or dict
    :param ax: matplotlib axis to plot on, if None, will create figure
    :param show_cbar: whether or not to display colour bar (skipped when there are no banks)
    :param make_pretty: get rid of spines on axis
    :param fig_kwargs: dict of matplotlib keywords associated with plt.subplots e.g can be
     fig size, tight layout etc. A 'figsize' given here overrides the default of (2, 8)
    :param line_kwargs: dict of matplotlib keywords associated with ax.hlines/ax.vlines
    :return: matplotlib axis and figure handles
    """
    # None defaults avoid sharing one mutable dict between calls
    fig_kwargs = fig_kwargs or dict()
    line_kwargs = line_kwargs or dict()

    if not isinstance(data, dict):
        data = data.convert2dict()

    if ax is None:
        # Merge so a caller-supplied figsize replaces the default instead of raising a
        # duplicate-keyword TypeError
        fig, ax = plt.subplots(**{'figsize': (2, 8), **fig_kwargs})
    else:
        # Use the figure that owns ax; plt.gcf() may be a different (current) figure
        fig = ax.figure

    im = None
    for (x, y, dat) in zip(data['data']['x'], data['data']['y'], data['data']['c']):
        # One image per bank, all sharing the same colour limits
        im = NonUniformImage(ax, interpolation='nearest', cmap=data['cmap'])
        im.set_clim(data['clim'][0], data['clim'][1])
        im.set_data(x, y, dat.T)
        ax.add_image(im)

    ax.set_xlim(data['xlim'][0], data['xlim'][1])
    ax.set_ylim(data['ylim'][0], data['ylim'][1])
    ax.set_xlabel(data['labels']['xlabel'])
    ax.set_ylabel(data['labels']['ylabel'])
    ax.set_title(data['labels']['title'])

    if make_pretty:
        ax.get_xaxis().set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)

    # The colorbar is built from the last bank's image, so there must be at least one
    if show_cbar and im is not None:
        cbar = fig.colorbar(im, orientation="horizontal", pad=0.02, ax=ax)
        cbar.set_label(data['labels']['clabel'])

    ax = add_lines(ax, data, **line_kwargs)

    return ax, fig

592 

593 

def plot_line(data, ax=None, fig_kwargs=None, line_kwargs=None):
    """
    Function to create matplotlib plot from LinePlot object

    :param data: LinePlot object either class or dict
    :param ax: matplotlib axis to plot on, if None, will create figure
    :param fig_kwargs: dict of matplotlib keywords associated with plt.subplots e.g can be
     fig size, tight layout etc.
    :param line_kwargs: dict of matplotlib keywords associated with ax.hlines/ax.vlines
    :return: matplotlib axis and figure handles
    """
    # None defaults avoid sharing one mutable dict between calls
    fig_kwargs = fig_kwargs or dict()
    line_kwargs = line_kwargs or dict()

    if not isinstance(data, dict):
        data = data.convert2dict()

    if ax is None:
        fig, ax = plt.subplots(**fig_kwargs)
    else:
        # Use the figure that owns ax; plt.gcf() may be a different (current) figure
        fig = ax.figure

    ax.plot(data['data']['x'], data['data']['y'], color=data['line_color'],
            linestyle=data['line_style'], linewidth=data['line_width'], marker=data['marker_type'],
            markersize=data['marker_size'])
    ax = add_lines(ax, data, **line_kwargs)

    ax.set_xlim(data['xlim'][0], data['xlim'][1])
    ax.set_ylim(data['ylim'][0], data['ylim'][1])
    ax.set_xlabel(data['labels']['xlabel'])
    ax.set_ylabel(data['labels']['ylabel'])
    ax.set_title(data['labels']['title'])

    return ax, fig

625 

626 

def scatter_xyc_plot(x, y, c, cmap=None, clim=None, rgb=False):
    """
    General function for preparing x y scatter plot with third variable encoded by colour of points

    :param x: x values of the points
    :param y: y values of the points
    :param c: values encoded by colour, one per point
    :param cmap: name of colormap (defaults to 'viridis' via ScatterPlot)
    :param clim: (min, max) colour limits; if None computed as nanmin, nanmax of c
    :param rgb: Whether to compute rgb (set True when preparing pyqtgraph data). When True the
     per-point RGBA colours (clipped to clim) are stored with set_color
    :return: ScatterPlot object
    """

    data = ScatterPlot(x=x, y=y, c=c, cmap=cmap)
    data.set_clim(clim=clim)
    if rgb:
        norm = matplotlib.colors.Normalize(vmin=data.clim[0], vmax=data.clim[1], clip=True)
        # Use the plot's resolved colormap so the RGB values always agree with data.cmap, even
        # when cmap is None and the matplotlib default colormap has been changed
        mapper = cm.ScalarMappable(norm=norm, cmap=data.cmap)
        # Map all values at once -> (n_points, 4) RGBA array
        cluster_color = np.asarray(mapper.to_rgba(np.asarray(c)))
        data.set_color(color=cluster_color)

    return data

648 

649 

def arrange_channels2banks(data, chn_coords, depth=None, pad=True, x_offset=1):
    """
    Rearranges data on channels so it matches geometry of probe. e.g For Neuropixel 2.0 rearranges
    channels into 4 banks with checkerboard pattern

    Channels are grouped into banks by their x coordinate (chn_coords[:, 0]), one bank per unique
    x value, in ascending order of that value.

    :param data: 1D array of values, one per channel (same order as chn_coords)
    :param chn_coords: local coordinates of channels on probe; column 0 is x, column 1 is y
    :param depth: depth location of electrode (for example could be relative to bregma). If none
     given will stay in probe local coordinates (chn_coords[:, 1])
    :param pad: for matplotlib implementation with NonUniformImage we need to surround our data
     with nans so that it shows as finite display. Integer/boolean data is converted to float so
     it can hold the NaN border
    :param x_offset: spacing between banks in x direction
    :return: list, data, x position and y position for each bank. Without padding a bank's data
     has shape (1, n) with 2 x values and n y values; with padding it has shape (3, n + 2) with
     3 x values and n + 2 y values
    """
    data_bank = []
    x_bank = []
    y_bank = []

    if depth is None:
        depth = chn_coords[:, 1]

    for iX, x in enumerate(np.unique(chn_coords[:, 0])):
        bnk_idx = np.where(chn_coords[:, 0] == x)[0]
        bnk_data = data[bnk_idx, np.newaxis].T
        # This is a hack! Although data is 1D we give it two x coords so we can correctly set
        # scale and extent (compatible with pyqtgraph and matplotlib.imshow)
        # For matplotlib.image.Nonuniformimage must use pad=True option
        bnk_x = (iX + np.arange(2)) * x_offset
        bnk_y = depth[bnk_idx]
        if pad:
            # Integer and boolean arrays cannot store NaN, so promote them to float first
            if bnk_data.dtype.kind in 'biu':
                bnk_data = bnk_data.astype(float)
            # Surround the (1, n) bank with a NaN border in both directions -> (3, n + 2)
            bnk_data = np.pad(bnk_data, 1, mode='constant', constant_values=np.nan)

            # pad the x values; integer steps scaled by x_offset always give exactly 3 values,
            # whereas np.arange with a float step can return an extra element
            bnk_x = (iX + np.arange(3)) * x_offset

            # pad the y values by extrapolating the first/last non-zero depth spacing.
            # NOTE(review): raises IndexError if every channel in the bank shares one depth
            diff = np.diff(bnk_y)
            diff = diff[np.nonzero(diff)]

            bnk_y = np.insert(bnk_y, 0, bnk_y[0] - np.abs(diff[0]))
            bnk_y = np.append(bnk_y, bnk_y[-1] + np.abs(diff[-1]))

        data_bank.append(bnk_data)
        x_bank.append(bnk_x)
        y_bank.append(bnk_y)

    return data_bank, x_bank, y_bank
699 x_bank.append(bnk_x) 1lik

700 y_bank.append(bnk_y) 1lik

701 

702 return data_bank, x_bank, y_bank 1lik