Coverage for brainbox/ephys_plots.py: 21%

255 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1import numpy as np 

2from matplotlib import cm 

3import matplotlib.pyplot as plt 

4from brainbox.plot_base import (ImagePlot, ScatterPlot, ProbePlot, LinePlot, plot_line, 

5 plot_image, plot_probe, plot_scatter, arrange_channels2banks) 

6from brainbox.processing import compute_cluster_average 

7from iblutil.numerical import bincount2D 

8from iblatlas.regions import BrainRegions 

9 

10 

def image_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords=None, chn_inds=None, freq_range=(0, 300),
                            avg_across_depth=False, clim=None, cmap='viridis', display=False, title=None,
                            **kwargs):
    """
    Prepare data for 2D image plot of LFP power spectrum along depth of probe

    :param lfp_power: LFP power, indexed as [frequency, channel]
    :param lfp_freq: frequency associated with each row of lfp_power
    :param chn_coords: channel coordinates, column 1 is used as the y axis. If None the channel
     index is used instead
    :param chn_inds: indices of the channels to take from lfp_power, by default all channels
    :param freq_range: frequencies f kept are those with freq_range[0] <= f < freq_range[1]
    :param avg_across_depth: Whether to average across channels at same depth
    :param clim: colour limits [min, max], by default the 10th and 90th percentiles of the dB values
    :param cmap: colormap handed to ImagePlot
    :param display: generate figure
    :param title: plot title, by default 'LFP Power Spectrum'
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    n_chn = lfp_power.shape[1]
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or 'LFP Power Spectrum'

    y = np.arange(n_chn) if chn_coords is None else chn_coords[:, 1]
    chn_inds = np.arange(n_chn) if chn_inds is None else chn_inds

    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    x = lfp_freq[freq_idx]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    lfp_db = 10 * np.log10(lfp)
    # log10(0) gives -inf; mark those entries as missing rather than letting them set the colour scale
    lfp_db[np.isinf(lfp_db)] = np.nan

    # Average across channels that are at the same depth
    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # For depths that occur twice, pair the first occurrence with the channel right after it;
        # depths that occur once are averaged with themselves
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1
        lfp_db = np.mean([lfp_db[:, chn_idx], lfp_db[:, chn_idx_eq]], axis=0)
        y = chn_depth

    data = ImagePlot(lfp_db, x=x, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Frequency (Hz)',
                    ylabel=ylabel, clabel='LFP Power (dB)')
    # `clim or ...` raises on array input, and a plain quantile returns NaN as soon as the
    # data contains the NaNs inserted above
    if clim is None:
        clim = np.nanquantile(lfp_db, [0.1, 0.9])
    data.set_clim(clim=clim)

    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data

64 

65 

def image_rms_plot(rms_amps, rms_times, chn_coords=None, chn_inds=None, avg_across_depth=False,
                   median_subtract=True, clim=None, cmap='plasma', band='AP', display=False, title=None,
                   **kwargs):
    """
    Prepare data for 2D image plot of RMS data along depth of probe

    :param rms_amps: RMS amplitudes, indexed as [time, channel]
    :param rms_times: time associated with each row of rms_amps, used as the x axis
    :param chn_coords: channel coordinates, column 1 is used as the y axis. If None the channel
     index is used instead
    :param chn_inds: indices of the channels to take from rms_amps, by default all channels
    :param avg_across_depth: Whether to average across channels at same depth
    :param median_subtract: Whether to apply median subtraction correction
    :param clim: colour limits [min, max], by default the 10th and 90th percentiles of the dB values
    :param cmap: colormap handed to ImagePlot
    :param band: Frequency band of rms data, can be either 'LF' or 'AP'
    :param display: generate figure
    :param title: plot title, by default '<band> RMS'
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    n_chn = rms_amps.shape[1]
    ylabel = 'Channel index' if chn_coords is None else 'Distance from probe tip (um)'
    title = title or f'{band} RMS'
    chn_inds = np.arange(n_chn) if chn_inds is None else chn_inds
    y = np.arange(n_chn) if chn_coords is None else chn_coords[:, 1]

    rms = 10 * np.log10(rms_amps[:, chn_inds])
    x = rms_times

    if avg_across_depth:
        chn_depth, chn_idx, chn_count = np.unique(y, return_index=True, return_counts=True)
        # For depths that occur twice, pair the first occurrence with the channel right after it;
        # depths that occur once are averaged with themselves
        chn_idx_eq = np.copy(chn_idx)
        chn_idx_eq[np.where(chn_count == 2)] += 1
        rms = np.mean([rms[:, chn_idx], rms[:, chn_idx_eq]], axis=0)
        y = chn_depth

    if median_subtract:
        # Remove each row's own median, then add back the mean of the row medians so the
        # overall level of the image is retained
        row_median = np.median(rms, axis=1, keepdims=True)
        rms = rms - row_median + np.mean(row_median)

    data = ImagePlot(rms, x=x, y=y, cmap=cmap)
    data.set_labels(title=title, xlabel='Time (s)', ylabel=ylabel, clabel=f'{band} RMS (dB)')
    # `clim or ...` raises on array input; test for None explicitly
    if clim is None:
        clim = np.nanquantile(rms, [0.1, 0.9])
    data.set_clim(clim=clim)

    if display:
        ax, fig = plot_image(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data

113 

114 

def scatter_raster_plot(spike_amps, spike_depths, spike_times, n_amp_bins=10, cmap='BuPu',
                        subsample_factor=100, display=False, title=None, **kwargs):
    """
    Prepare data for 2D raster plot of spikes with colour and size indicative of spike amplitude

    :param spike_amps: amplitude of each spike (multiplied by 1e6 for the colour values)
    :param spike_depths: depth of each spike, used as the y axis
    :param spike_times: time of each spike, used as the x axis
    :param n_amp_bins: no. of colour and size bins into which to split amplitude data
    :param cmap: name of the colormap used both for the per-spike colours and for the ScatterPlot
    :param subsample_factor: factor by which to subsample data when too many points for efficient
     display
    :param display: generate figure
    :param title: plot title, by default 'Spike times vs Spike depths'
    :param kwargs: forwarded to plot_scatter when display=True
    :return: ScatterPlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    title = title or 'Spike times vs Spike depths'
    # Bins span the minimum to the 90th percentile; anything above is treated as saturated
    amp_range = np.quantile(spike_amps, [0, 0.9])
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    color_bin = np.linspace(0.0, 1.0, n_amp_bins + 1)
    # plt.get_cmap replaces cm.get_cmap, which is no longer available in recent matplotlib;
    # keep only the RGB columns of the RGBA output
    colors = plt.get_cmap(cmap)(color_bin)[:, :3]

    # NB: the stop of -1 means the very last spike is never part of the subsample
    subsample = slice(0, -1, subsample_factor)
    spike_amps = spike_amps[subsample]
    spike_colors = np.zeros((spike_amps.size, 3))
    spike_size = np.zeros(spike_amps.size)
    last_bin = amp_bins.size - 1
    for iA in range(amp_bins.size):
        if iA == last_bin:
            idx = np.where(spike_amps > amp_bins[iA])[0]
            # Make saturated spikes the darkest colour
            spike_colors[idx] = colors[-1]
        else:
            idx = np.where((spike_amps > amp_bins[iA]) & (spike_amps <= amp_bins[iA + 1]))[0]
            spike_colors[idx] = colors[iA]

        spike_size[idx] = iA / (n_amp_bins / 8)

    # Use the requested colormap (previously hard-coded to 'BuPu' regardless of `cmap`)
    data = ScatterPlot(x=spike_times[subsample], y=spike_depths[subsample],
                       c=spike_amps * 1e6, cmap=cmap)
    data.set_ylim((0, 3840))
    data.set_color(color=spike_colors)
    data.set_clim(clim=amp_range * 1e6)
    data.set_marker_size(marker_size=spike_size)
    data.set_labels(title=title, xlabel='Time (s)',
                    ylabel='Distance from probe tip (um)', clabel='Spike amplitude (uV)')

    if display:
        ax, fig = plot_scatter(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data

165 

166 

def image_fr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=5, cmap='binary',
                  display=False, title=None, **kwargs):
    """
    Prepare data 2D raster plot of firing rate across recording

    Spikes are counted with bincount2D in bins of t_bin by d_bin, the depth axis running from 0 to
    the largest value in column 1 of chn_coords, and the counts are divided by t_bin.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates, column 1 sets the upper depth limit
    :param t_bin: time bin passed to bincount2D
    :param d_bin: depth bin passed to bincount2D
    :param cmap: colormap handed to ImagePlot
    :param display: generate figure
    :param title: plot title, by default 'Firing Rate'
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, t_scale, d_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits)
    rate = counts.T / t_bin

    image = ImagePlot(rate, x=t_scale, y=d_scale, cmap=cmap)
    image.set_labels(title=title or 'Firing Rate', xlabel='Time (s)',
                     ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')
    # Colour limits span the extremes of the rate image averaged along its first axis
    mean_rate = np.mean(rate, axis=0)
    image.set_clim(clim=(np.min(mean_rate), np.max(mean_rate)))

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax

196 

197 

def image_crosscorr_plot(spike_depths, spike_times, chn_coords, t_bin=0.05, d_bin=40,
                         cmap='viridis', display=False, title=None, **kwargs):
    """
    Prepare data for 2D cross correlation plot of data across depth

    Spikes are counted with bincount2D and np.corrcoef is applied to the resulting count matrix;
    NaN correlation coefficients are replaced by 0.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates, column 1 sets the upper depth limit
    :param t_bin: time bin passed to bincount2D
    :param d_bin: depth bin passed to bincount2D
    :param cmap: colormap handed to ImagePlot
    :param display: generate figure
    :param title: plot title, by default 'Correlation'
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, _, d_scale = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits)

    correlation = np.corrcoef(counts)
    correlation[np.isnan(correlation)] = 0

    depth_label = 'Distance from probe tip (um)'
    # Depth is on both axes
    image = ImagePlot(correlation, x=d_scale, y=d_scale, cmap=cmap)
    image.set_labels(title=title or 'Correlation', xlabel=depth_label,
                     ylabel=depth_label, clabel='Correlation')

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax

228 

229 

def scatter_amp_depth_fr_plot(spike_amps, spike_clusters, spike_depths, spike_times, cmap='hot',
                              display=False, title=None, **kwargs):
    """
    Prepare data for 2D scatter plot of cluster depth vs cluster amp with colour indicating cluster
    firing rate

    :param spike_amps: amplitude of each spike (multiplied by 1e6 for the x axis)
    :param spike_clusters: cluster assignment of each spike
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike; the maximum is used as the recording duration
    :param cmap: colormap handed to ScatterPlot
    :param display: generate figure
    :param title: plot title, by default 'Cluster depth vs amp vs firing rate'
    :param kwargs: forwarded to plot_scatter when display=True
    :return: ScatterPlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    # TODO use pandas here instead, much quicker
    _, depth_avg, spike_count = compute_cluster_average(spike_clusters, spike_depths)
    _, amp_avg, _ = compute_cluster_average(spike_clusters, spike_amps)
    amp_uv = amp_avg * 1e6
    firing_rate = spike_count / np.max(spike_times)

    scatter = ScatterPlot(x=amp_uv, y=depth_avg, c=firing_rate, cmap=cmap)
    # 10 % margin either side of the amplitude range
    scatter.set_xlim((0.9 * np.min(amp_uv), 1.1 * np.max(amp_uv)))
    scatter.set_labels(title=title or 'Cluster depth vs amp vs firing rate',
                       xlabel='Cluster Amplitude (uV)', ylabel='Distance from probe tip (um)',
                       clabel='Firing rate (Hz)')

    if not display:
        return scatter

    ax, fig = plot_scatter(scatter.convert2dict(), **kwargs)
    return scatter.convert2dict(), fig, ax

262 

263 

def probe_lfp_spectrum_plot(lfp_power, lfp_freq, chn_coords, chn_inds, freq_range=(0, 4),
                            display=False, pad=True, x_offset=1, **kwargs):
    """
    Prepare data for 2D probe plot of LFP power spectrum along depth of probe

    :param lfp_power: LFP power, indexed as [frequency, channel]
    :param lfp_freq: frequency associated with each row of lfp_power
    :param chn_coords: channel coordinates passed to arrange_channels2banks
    :param chn_inds: indices of the channels to take from lfp_power
    :param freq_range: frequencies f averaged are those with freq_range[0] <= f < freq_range[1]
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
     for pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: forwarded to plot_probe when display=True
    :return: ProbePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    freq_idx = np.where((lfp_freq >= freq_range[0]) & (lfp_freq < freq_range[1]))[0]
    lfp = np.take(lfp_power[freq_idx], chn_inds, axis=1)
    lfp_db = 10 * np.log10(lfp)
    # log10(0) gives -inf; mark those entries as missing
    lfp_db[np.isinf(lfp_db)] = np.nan
    # One value per channel: mean dB power over the selected frequencies
    lfp_db = np.mean(lfp_db, axis=0)

    data_bank, x_bank, y_bank = arrange_channels2banks(lfp_db, chn_coords, depth=None,
                                                       pad=pad, x_offset=x_offset)
    data = ProbePlot(data_bank, x=x_bank, y=y_bank)
    # Label reflects the band actually used (was hard-coded to 0-4 Hz); unchanged for the default
    data.set_labels(ylabel='Distance from probe tip (um)',
                    clabel=f'PSD {freq_range[0]}-{freq_range[1]} Hz (dB)')
    # Colour limits from the 10th/90th percentiles over all banks, ignoring NaNs
    clim = np.nanquantile(np.concatenate([np.squeeze(np.ravel(d)) for d in data_bank]).ravel(),
                          [0.1, 0.9])
    data.set_clim(clim)

    if display:
        ax, fig = plot_probe(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax

    return data

300 

301 

def probe_rms_plot(rms_amps, chn_coords, chn_inds, cmap='plasma', band='AP',
                   display=False, pad=True, x_offset=1, **kwargs):
    """
    Prepare data for 2D probe plot of RMS along depth of probe

    :param rms_amps: RMS amplitudes; averaged along axis 0, then indexed by chn_inds and scaled by 1e6
    :param chn_coords: channel coordinates passed to arrange_channels2banks
    :param chn_inds: indices of the channels to keep
    :param cmap: colormap handed to ProbePlot
    :param band: band name used in the colourbar label
    :param display: generate figure
    :param pad: whether to add nans around the individual image plots. For matplotlib use pad=True,
     for pyqtgraph use pad=False
    :param x_offset: Distance between the channel banks in x direction
    :param kwargs: forwarded to plot_probe when display=True
    :return: ProbePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    channel_rms = np.mean(rms_amps, axis=0)[chn_inds] * 1e6

    banks, bank_x, bank_y = arrange_channels2banks(channel_rms, chn_coords, depth=None,
                                                   pad=pad, x_offset=x_offset)
    probe = ProbePlot(banks, x=bank_x, y=bank_y, cmap=cmap)
    probe.set_labels(ylabel='Distance from probe tip (um)', clabel=f'{band} RMS (uV)')

    # Colour limits from the 10th/90th percentiles over all banks, ignoring NaNs
    flattened = [np.squeeze(np.ravel(bank)) for bank in banks]
    all_values = np.concatenate(flattened).ravel()
    probe.set_clim(np.nanquantile(all_values, [0.1, 0.9]))

    if not display:
        return probe

    ax, fig = plot_probe(probe.convert2dict(), **kwargs)
    return probe.convert2dict(), fig, ax

334 

335 

def line_fr_plot(spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None, **kwargs):
    """
    Prepare data for 1D line plot of average firing rate across depth

    The whole recording is treated as one time bin (width = max of spike_times), so the first
    time column of the bincount2D output divided by that width is the rate per depth bin.

    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates, column 1 sets the upper depth limit
    :param d_bin: depth bin passed to bincount2D
    :param display: generate figure
    :param title: plot title, by default 'Avg Firing Rate'
    :param kwargs: forwarded to plot_line when display=True
    :return: LinePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    duration = np.max(spike_times)
    depth_limits = [0, np.max(chn_coords[:, 1])]
    counts, _, d_scale = bincount2D(spike_times, spike_depths, duration, d_bin, ylim=depth_limits)
    rate = counts[:, 0] / duration

    line = LinePlot(x=rate, y=d_scale)
    line.set_xlim((0, np.max(rate)))
    line.set_labels(title=title or 'Avg Firing Rate', xlabel='Firing Rate (Hz)',
                    ylabel='Distance from probe tip (um)')

    if not display:
        return line

    ax, fig = plot_line(line.convert2dict(), **kwargs)
    return line.convert2dict(), fig, ax

364 

365 

def line_amp_plot(spike_amps, spike_depths, spike_times, chn_coords, d_bin=10, display=False, title=None,
                  **kwargs):
    """
    Prepare data for 1D line plot of average spike amplitude across depth

    The whole recording is treated as one time bin (width = max of spike_times). Per depth bin the
    amplitude-weighted count is divided by the spike count and scaled by 1e6; bins holding fewer
    than 50 spikes are set to 0.

    :param spike_amps: amplitude of each spike, used as weights in bincount2D
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param chn_coords: channel coordinates, column 1 sets the upper depth limit
    :param d_bin: depth bin passed to bincount2D
    :param display: generate figure
    :param title: plot title, by default 'Avg Amplitude'
    :param kwargs: forwarded to plot_line when display=True
    :return: LinePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """
    title = title or 'Avg Amplitude'
    t_bin = np.max(spike_times)
    depth_limits = [0, np.max(chn_coords[:, 1])]
    n, _, _ = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits)
    amp, x, y = bincount2D(spike_times, spike_depths, t_bin, d_bin, ylim=depth_limits,
                           weights=spike_amps)

    n_spikes = n[:, 0]
    # Empty depth bins give 0/0; the result is zeroed just below, so silence the warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_amp = np.divide(amp[:, 0], n_spikes) * 1e6
    mean_amp[np.isnan(mean_amp)] = 0
    # Averages over fewer than 50 spikes are not shown
    mean_amp[np.where(n_spikes < 50)[0]] = 0

    data = LinePlot(x=mean_amp, y=y)
    data.set_xlim((0, np.max(mean_amp)))
    data.set_labels(title=title, xlabel='Amplitude (uV)',
                    ylabel='Distance from probe tip (um)')
    if display:
        ax, fig = plot_line(data.convert2dict(), **kwargs)
        return data.convert2dict(), fig, ax
    return data

397 

398 

def plot_brain_regions(channel_ids, channel_depths=None, brain_regions=None, display=True, ax=None,
                       title=None, label='left', **kwargs):
    """
    Plot brain regions along probe, if channel depths is provided will plot along depth otherwise along channel idx

    :param channel_ids: atlas ids for each channel
    :param channel_depths: depth along probe for each channel
    :param brain_regions: BrainRegions object, a new one is instantiated if not given
    :param display: whether to output plot
    :param ax: axis to plot on, a new figure is created if not given
    :param title: title for plot
    :param label: 'right' puts the region ticks and labels on the right hand side, any other value
     leaves them on the left
    :param kwargs: additional keyword arguments for bar plot, these override the defaults
     (edgecolor='w', width=1)
    :return: fig, ax if display=True, otherwise regions, region_labels, region_colours
    """

    if channel_depths is not None:
        assert channel_ids.shape[0] == channel_depths.shape[0]
    else:
        channel_depths = np.arange(channel_ids.shape[0])

    br = BrainRegions() if brain_regions is None else brain_regions

    region_info = br.get(channel_ids)
    # Indices where the region id changes, bracketed by the first and last channel
    boundaries = np.where(np.diff(region_info.id) != 0)[0]
    boundaries = np.r_[0, boundaries, region_info.id.shape[0] - 1]

    # One [start, end] row per region, expressed in depth (channel_depths is always set by now)
    regions = np.c_[boundaries[0:-1], boundaries[1:]]
    regions = channel_depths[regions]
    region_labels = np.c_[np.mean(regions, axis=1), region_info.acronym[boundaries[1:]]]
    region_colours = region_info.rgb[boundaries[1:]]

    if not display:
        return regions, region_labels, region_colours

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Defaults first so that caller supplied kwargs take precedence
    bar_kwargs = dict(edgecolor='w', width=1)
    bar_kwargs.update(**kwargs)
    for reg, col in zip(regions, region_colours):
        height = np.abs(reg[1] - reg[0])
        color = col / 255
        # Pass the merged bar_kwargs; passing the raw kwargs silently dropped the defaults
        ax.bar(x=0.5, height=height, color=color, bottom=reg[0], **bar_kwargs)

    ax.set_yticks(region_labels[:, 0].astype(int))
    ax.yaxis.set_tick_params(labelsize=8)
    ax.set_ylim(np.nanmin(channel_depths), np.nanmax(channel_depths))
    ax.get_xaxis().set_visible(False)
    ax.set_yticklabels(region_labels[:, 1])
    if label == 'right':
        ax.yaxis.tick_right()
        ax.spines['left'].set_visible(False)
    else:
        ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    if title:
        ax.set_title(title)

    return fig, ax

462 

463 

def plot_cdf(spike_amps, spike_depths, spike_times, n_amp_bins=10, d_bin=40, amp_range=None, d_range=None,
             display=False, cmap='hot', ax=None):
    """
    Plot cumulative amplitude of spikes across depth

    For every depth bin the spikes are histogrammed by amplitude, divided by the recording duration
    (max of spike_times) and accumulated from the highest amplitude bin downwards, so each value is
    the rate of spikes in that bin or any higher one. Zero entries are replaced by NaN.

    :param spike_amps: amplitude of each spike
    :param spike_depths: depth of each spike
    :param spike_times: time of each spike
    :param n_amp_bins: number of amplitude bins to use
    :param d_bin: the value of the depth bins in um (default is 40 um)
    :param amp_range: amp range to use [amp_min, amp_max], if not given automatically computed from spike_amps
    :param d_range: depth range to use, by default [0, 3840]
    :param display: whether or not to display plot
    :param cmap: colormap handed to ImagePlot
    :param ax: axis passed to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    # Explicit None checks: `x or default` raises when x is a numpy array
    if amp_range is None:
        amp_range = np.quantile(spike_amps, (0, 0.9))
    amp_bins = np.linspace(amp_range[0], amp_range[1], n_amp_bins)
    if d_range is None:
        d_range = [0, 3840]
    depth_bins = np.arange(d_range[0], d_range[1] + d_bin, d_bin)
    t_bin = np.max(spike_times)

    def histc(x, bins):
        """Count the values of x per bin; the last bin also collects everything above the last edge."""
        bin_idx = np.digitize(x, bins)
        # digitize returns 0 for values below the first edge; drop them, otherwise `idx - 1`
        # becomes -1 and they would be counted in the last (highest amplitude) bin
        bin_idx = bin_idx[bin_idx > 0]
        return np.bincount(bin_idx - 1, minlength=bins.size).astype(float)

    cdfs = np.empty((len(depth_bins) - 1, n_amp_bins))
    for d in range(len(depth_bins) - 1):
        spikes = np.bitwise_and(spike_depths > depth_bins[d], spike_depths <= depth_bins[d + 1])
        h = histc(spike_amps[spikes], amp_bins) / t_bin
        # Reverse cumulative sum: accumulate from the highest amplitude bin down
        cdfs[d, :] = np.cumsum(h[::-1])[::-1]

    cdfs[cdfs == 0] = np.nan

    data = ImagePlot(cdfs.T, x=amp_bins * 1e6, y=depth_bins[:-1], cmap=cmap)
    data.set_labels(title='Cumulative Amplitude', xlabel='Spike amplitude (uV)',
                    ylabel='Distance from probe tip (um)', clabel='Firing Rate (Hz)')

    if display:
        ax, fig = plot_image(data.convert2dict(), fig_kwargs={'figsize': [3, 7]}, ax=ax)
        return data.convert2dict(), fig, ax

    return data

512 

513 

def image_raw_data(raw, fs, chn_coords=None, cmap='bone', title=None, display=False, gain=-90, **kwargs):
    """
    Prepare data for 2D image plot of a snippet of raw data

    :param raw: raw data, indexed as [sample, channel]
    :param fs: sampling frequency, used to convert the sample extent to milliseconds
    :param chn_coords: channel coordinates, column 1 is used as the y axis. If None the channel
     index is used instead
    :param cmap: colormap handed to ImagePlot
    :param title: plot title, by default 'Raw data'
    :param display: generate figure
    :param gain: sets symmetric colour limits of +/- 4 * 10 ** (gain / 20)
     (presumably a gain in dB - confirm with callers)
    :param kwargs: forwarded to plot_image when display=True
    :return: ImagePlot object; if display=True returns instead the plot dict (convert2dict) together
     with the matplotlib fig and ax objects
    """

    if chn_coords is None:
        ylabel = 'Channel index'
        y = np.arange(raw.shape[1])
    else:
        ylabel = 'Distance from probe tip (um)'
        y = chn_coords[:, 1]

    # Time of the first and last sample in ms
    time_extent = np.array([0, raw.shape[0] - 1]) / fs * 1e3
    # Symmetric colour range derived from the gain
    levels = 10 ** (gain / 20) * 4 * np.array([-1, 1])

    image = ImagePlot(raw, y=y, cmap=cmap)
    image.set_labels(title=title or 'Raw data', xlabel='Time (ms)',
                     ylabel=ylabel, clabel='Power (uV)')
    image.set_clim(clim=levels)
    image.set_xlim(xlim=time_extent)
    image.set_ylim()

    if not display:
        return image

    ax, fig = plot_image(image.convert2dict(), **kwargs)
    return image.convert2dict(), fig, ax