Coverage for brainbox/atlas.py: 0%

81 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-11 11:13 +0100

1""" 

2Functions which map metrics to the Allen atlas. 

3 

4Code by G. Meijer 

5""" 

6 

7import numpy as np 

8import seaborn as sns 

9import matplotlib.pyplot as plt 

10from iblatlas import atlas 

11 

12 

13def _label2values(imlabel, fill_values, ba): 

14 """ 

15 Fills a slice from the label volume with values to display 

16 :param imlabel: 2D np-array containing label ids (slice of the label volume) 

17 :param fill_values: 1D np-array containing values to fill into the slice 

18 :return: 2D np-array filled with values 

19 """ 

20 im_unique, ilabels, iim = np.unique(imlabel, return_index=True, return_inverse=True) 

21 _, ir_unique, _ = np.intersect1d(ba.regions.id, im_unique, return_indices=True) 

22 im = np.squeeze(np.reshape(fill_values[ir_unique[iim]], (*imlabel.shape, 1))) 

23 return im 

24 

25 

def _region_boundaries(label):
    """
    Mark the voxels where the label value changes along any axis.

    :param label: N-D np-array of region labels
    :return: uint8 np-array with the shape of ``label``; 1 on boundaries, 0 elsewhere
    """
    edges = np.zeros(label.shape, dtype=bool)
    for axis in range(label.ndim):
        # ``append=0`` keeps the shape, so the last plane along each axis is flagged
        # wherever its label is non-zero. Axes are OR-ed rather than summed: signed
        # differences along two axes (e.g. +1 and -1) would cancel to 0 and drop a
        # real boundary voxel.
        edges |= np.diff(label, axis=axis, append=0) != 0
    return edges.astype(np.uint8)


def _draw_slice(ax, image, title, color_map, cmin, cmax):
    """
    Draw one atlas slice as a heatmap on ``ax`` and hide that axis' frame and ticks.
    """
    sns.heatmap(np.rot90(image, 3), cmap=color_map, cbar=True, vmin=cmin, vmax=cmax, ax=ax)
    ax.set(title=title)
    # Act on ``ax`` itself: plt.axis('off') targets pyplot's *current* axes, which is
    # not necessarily ``ax`` (e.g. when the caller passes in ``axs``).
    ax.axis('off')


def plot_atlas(regions, values, ML=-1, AP=0, DV=-1, hemisphere='left', color_palette='Reds',
               minmax=None, axs=None, custom_region_list=None):
    """
    Plot a sagittal, coronal and horizontal slice of the Allen atlas with regions colored in
    according to any value that the user specifies.

    Parameters
    ----------
    regions : 1D array-like
        Acronyms of brain regions (in Allen convention) that should be filled with color
    values : 1D array-like
        Values that correspond to the brain region acronyms; must have the same shape as
        `regions`. The caller's array is not modified.
    ML, AP, DV : float
        The coordinates of the slices in mm (divided by 1000 before being passed to
        ``ba.slice``)
    hemisphere : string
        Which hemisphere to color, options are 'left' (default), 'right', 'both'. Any other
        value colors both hemispheres.
    color_palette : any input that can be interpreted by sns.color_palette
        The color palette of the plot
    minmax : 2 element array
        The min and max of the color map, if None it uses the min and max of values.
        Values outside this range are clipped into it.
    axs : 3 element list of axis
        A list of the three axis in which to plot the three slices; if None a new 1x3
        figure is created
    custom_region_list : 1D array with shape the same as ba.regions.acronym.shape
        Input any custom list of acronyms that replaces the default list of acronyms
        found in ba.regions.acronym. For example if you want to merge certain regions you can
        give them the same name in the custom_region_list

    Raises
    ------
    ValueError
        If `regions` and `values` differ in shape, `values` is empty, `minmax` does not have
        two elements, `axs` does not have three elements, or `custom_region_list` does not
        have the shape of ba.regions.acronym.
    """

    # Check input before loading the atlas, which is the expensive step
    regions = np.asarray(regions)
    # Float copy: never modify the caller's array, and avoid integer truncation when the
    # values are clipped into the color range below.
    values = np.array(values, dtype=float)
    if regions.shape != values.shape:
        raise ValueError('regions and values must have the same shape')
    if values.size == 0:
        raise ValueError('values must not be empty')
    if minmax is not None and len(minmax) != 2:
        raise ValueError('minmax must have exactly 2 elements')
    if axs is not None and len(axs) != 3:
        raise ValueError('axs must contain exactly 3 axes')

    # Import Allen atlas
    ba = atlas.AllenAtlas(25)

    # Get all brain region names, use custom list if inputted
    if custom_region_list is None:
        all_regions = ba.regions.acronym
    else:
        all_regions = np.asarray(custom_region_list)
        if all_regions.shape != ba.regions.acronym.shape:
            raise ValueError('custom_region_list must have the same shape as '
                             'ba.regions.acronym')

    # Get color scale
    if minmax is None:
        cmin, cmax = np.min(values), np.max(values)
    else:
        cmin, cmax = minmax[0], minmax[1]

    # The color map built below has 1002 entries: white first, the 1000 palette colors, then
    # grey last. Matplotlib maps vmin to the first and vmax to the last entry, so pull every
    # value inside [cmin, cmax] by 1/1000 of the range width to keep it on palette colors.
    # This also clips values lying outside `minmax`.
    margin = (cmax - cmin) / 1000
    values = np.clip(values, cmin + margin, cmax - margin)

    # Sentinel fill values strictly outside [cmin, cmax]: below the range renders white
    # (background / no value), above the range renders grey (boundaries). They are derived
    # from the color range, since a value derived from `values` alone, such as
    # min(values) - (max(values) + 1), is not below cmin when max(values) <= -1.
    background = cmin - 1 - abs(cmin)
    boundary = cmax + 1 + abs(cmax)

    # Add values to brain region list; regions without a value get the background
    region_values = np.full(ba.regions.acronym.shape, background, dtype=float)
    for region, value in zip(regions, values):
        region_values[all_regions == region] = value

    # Set 'void' to default white
    region_values[0] = background

    # Region boundaries volume (1 on boundaries, 0 elsewhere)
    boundaries = _region_boundaries(ba.label)

    # Get slices with fill values
    slice_sag = _label2values(ba.slice(ML / 1000, axis=0, volume=ba.label), region_values, ba)
    bound_sag = ba.slice(ML / 1000, axis=0, volume=boundaries)  # sagittal
    slice_cor = _label2values(ba.slice(AP / 1000, axis=1, volume=ba.label), region_values, ba)
    bound_cor = ba.slice(AP / 1000, axis=1, volume=boundaries)  # coronal
    slice_hor = _label2values(ba.slice(DV / 1000, axis=2, volume=ba.label), region_values, ba)
    bound_hor = ba.slice(DV / 1000, axis=2, volume=boundaries)  # horizontal

    # Only color specified hemisphere: the coronal slice is split along its first axis and
    # the horizontal slice along its second axis, both at half the coronal slice's first
    # dimension (assumes both axes span the same direction with equal length — TODO confirm).
    ml_mid = int(slice_cor.shape[0] / 2)
    if hemisphere == 'left':
        slice_cor[:ml_mid, :] = background
        slice_hor[:, ml_mid:] = background
    elif hemisphere == 'right':
        slice_cor[ml_mid:, :] = background
        slice_hor[:, :ml_mid] = background
    # The sign of ML decides which hemisphere the sagittal slice falls in; blank it when that
    # is not the hemisphere being colored.
    if (hemisphere == 'left' and ML > 0) or (hemisphere == 'right' and ML < 0):
        slice_sag[:] = background

    # Add boundaries to slices outside of the fill value region and set to grey
    for image, bound in ((slice_sag, bound_sag), (slice_cor, bound_cor),
                         (slice_hor, bound_hor)):
        image[bound != 0] = boundary

    # Construct color map
    color_map = sns.color_palette(color_palette, 1000)
    color_map.append((0.8, 0.8, 0.8))  # color of the boundaries between regions
    color_map.insert(0, (1, 1, 1))  # color of the background and regions without a value

    # Plot
    if axs is None:
        _, axs = plt.subplots(1, 3, figsize=(16, 4))
    panels = (
        (slice_sag, 'ML: %.1f mm' % ML),  # sagittal
        (slice_cor, 'AP: %.1f mm' % AP),  # coronal
        (slice_hor, 'DV: %.1f mm' % DV),  # horizontal
    )
    for ax, (image, title) in zip(axs, panels):
        _draw_slice(ax, image, title, color_map, cmin, cmax)