Coverage for brainbox/atlas.py: 0%
81 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:13 +0100
1"""
2Functions which map metrics to the Allen atlas.
4Code by G. Meijer
5"""
7import numpy as np
8import seaborn as sns
9import matplotlib.pyplot as plt
10from iblatlas import atlas
def _label2values(imlabel, fill_values, ba):
    """
    Fills a slice from the label volume with values to display

    :param imlabel: 2D np-array containing label ids (slice of the label volume)
    :param fill_values: 1D np-array containing values to fill into the slice; it is indexed
        with positions into ba.regions.id
    :param ba: atlas object, only ba.regions.id is read here
    :return: 2D np-array filled with values (singleton dimensions are squeezed out)
    :raises IndexError: if imlabel contains a label id that is absent from ba.regions.id
    """
    # Sorted unique ids of the slice and, for every pixel, the position of its id in that list
    im_unique, iim = np.unique(imlabel, return_inverse=True)
    # Position of every slice id inside ba.regions.id. intersect1d returns the matches sorted
    # by id, so ir_unique lines up with im_unique only when every slice id was found.
    common, ir_unique, _ = np.intersect1d(ba.regions.id, im_unique, return_indices=True)
    if common.size != im_unique.size:
        missing = np.setdiff1d(im_unique, common)
        raise IndexError(f'label ids not found in ba.regions.id: {missing.tolist()}')
    # The inverse is flat or input-shaped depending on the NumPy version, so flatten it first
    flat = fill_values[ir_unique[np.ravel(iim)]]
    return np.squeeze(np.reshape(flat, imlabel.shape))
def plot_atlas(regions, values, ML=-1, AP=0, DV=-1, hemisphere='left', color_palette='Reds',
               minmax=None, axs=None, custom_region_list=None):
    """
    Plot a sagittal, coronal and horizontal slice of the Allen atlas with regions colored in
    according to any value that the user specifies.

    Parameters
    ----------
    regions : 1D array
        Array of strings with the acronyms of brain regions (in Allen convention) that should be
        filled with color
    values : 1D array
        Array of values that correspond to the brain region acronyms. The input is copied, so
        clipping to minmax does not modify the caller's array
    ML, AP, DV : float
        The coordinates of the slices in mm
    hemisphere : string
        Which hemisphere to color, options are 'left' (default), 'right', 'both'
    color_palette : any input that can be interpreted by sns.color_palette
        The color palette of the plot
    minmax : 2 element array
        The min and max of the color map, if None it uses the min and max of values
    axs : 3 element list of axis
        A list of the three axis in which to plot the three slices
    custom_region_list : 1D array with shape the same as ba.regions.acronym.shape
        Input any custom list of acronyms that replaces the default list of acronyms
        found in ba.regions.acronym. For example if you want to merge certain regions you can
        give them the same name in the custom_region_list

    Returns
    -------
    None
        The slices are drawn into axs, or into a newly created figure when axs is None.
    """
    # Work on array copies: list inputs are accepted and the clipping below never touches the
    # caller's data. Float dtype keeps the clipped bounds from being truncated for int input.
    regions = np.asarray(regions)
    values = np.array(values, dtype=float)
    if custom_region_list is not None:
        custom_region_list = np.asarray(custom_region_list)

    # Import Allen atlas
    ba = atlas.AllenAtlas(25)

    # Check input
    assert regions.shape == values.shape
    if minmax is not None:
        assert len(minmax) == 2
    if axs is not None:
        assert len(axs) == 3
    if custom_region_list is not None:
        assert custom_region_list.shape == ba.regions.acronym.shape

    # Get region boundaries volume: a voxel is a boundary when its label differs from the next
    # voxel along any axis. Every axis is tested on its own, because summing the signed
    # differences of the three axes lets opposite steps cancel out and drops real boundaries.
    is_boundary = np.zeros(ba.label.shape, dtype=bool)
    for axis in range(3):
        is_boundary |= np.diff(ba.label, axis=axis, append=0) != 0
    boundaries = is_boundary.astype(np.uint8)

    # Get all brain region names, use custom list if inputted
    all_regions = ba.regions.acronym if custom_region_list is None else custom_region_list

    # Pull values that fall outside the colormap bounds to just inside them, so they do not
    # end up in the two extra colors reserved for the background and the boundaries
    if minmax is not None:
        lower = minmax[0] + np.abs(minmax[0] / 1000)
        upper = minmax[1] - np.abs(minmax[1] / 1000)
        values[values < lower] = lower
        values[values > upper] = upper

    # Sentinel below the color range: used for 'void', regions without a value and blanked
    # hemispheres
    blank = np.min(values) - (np.max(values) + 1)

    # Add values to brain region list
    region_values = np.full(ba.regions.acronym.shape, blank)
    for region, value in zip(regions, values):
        region_values[all_regions == region] = value

    # Set 'void' to default white
    region_values[0] = blank

    # Get slices with fill values
    slice_sag = _label2values(ba.slice(ML / 1000, axis=0, volume=ba.label),
                              region_values, ba)  # sagittal
    bound_sag = ba.slice(ML / 1000, axis=0, volume=boundaries)
    slice_cor = _label2values(ba.slice(AP / 1000, axis=1, volume=ba.label),
                              region_values, ba)  # coronal
    bound_cor = ba.slice(AP / 1000, axis=1, volume=boundaries)
    slice_hor = _label2values(ba.slice(DV / 1000, axis=2, volume=ba.label),
                              region_values, ba)  # horizontal
    bound_hor = ba.slice(DV / 1000, axis=2, volume=boundaries)

    # Only color specified hemisphere
    # NOTE(review): the midline index comes from slice_cor.shape[0] and is also applied to the
    # second axis of slice_hor, and the two slices are blanked on opposite index halves -
    # confirm both against the orientation of the arrays returned by ba.slice
    half = int(slice_cor.shape[0] / 2)
    if hemisphere == 'left':
        slice_cor[:half, :] = blank
        slice_hor[:, half:] = blank
    elif hemisphere == 'right':
        slice_cor[half:, :] = blank
        slice_hor[:, :half] = blank
    if (hemisphere == 'left' and ML > 0) or (hemisphere == 'right' and ML < 0):
        slice_sag[:] = blank

    # Add boundaries to slices outside of the fill value region and set to grey
    boundary_value = (np.max(values) if minmax is None else minmax[1]) + 1
    slice_sag[bound_sag == 1] = boundary_value
    slice_cor[bound_cor == 1] = boundary_value
    slice_hor[bound_hor == 1] = boundary_value

    # Construct color map
    color_map = sns.color_palette(color_palette, 1000)
    color_map.append((0.8, 0.8, 0.8))  # color of the boundaries between regions
    color_map.insert(0, (1, 1, 1))  # color of the background and regions without a value

    # Get color scale
    if minmax is None:
        cmin = np.min(values)
        cmax = np.max(values)
    else:
        cmin = minmax[0]
        cmax = minmax[1]

    # Plot
    if axs is None:
        _, axs = plt.subplots(1, 3, figsize=(16, 4))

    panels = (
        (slice_sag, 'ML: %.1f mm' % ML),  # sagittal
        (slice_cor, 'AP: %.1f mm' % AP),  # coronal
        (slice_hor, 'DV: %.1f mm' % DV),  # horizontal
    )
    for ax, (image, title) in zip(axs, panels):
        sns.heatmap(np.rot90(image, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=ax)
        ax.set(title=title)
        # Switch off the axis of this panel; plt.axis would act on pyplot's current axes,
        # which need not be the panel that was just drawn
        ax.axis('off')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)