Coverage for brainbox/task/passive.py: 94%

101 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2Functions dealing with passive task 

3""" 

4import numpy as np 

5from iblutil.numerical import bincount2D 

6from scipy.linalg import svd 

7 

8 

def get_on_off_times_and_positions(rf_map):
    """
    Prepares passive receptive field mapping into format for analysis

    Parameters
    ----------
    rf_map: output from brainbox.io.one.load_passive_rfmap; a dict with 'times' (time of each
        frame) and 'frames' (array indexed as [frame, x_pos, y_pos] holding pixel values)

    Returns
    -------
    rf_map_times: time of each receptive field map frame, i.e. rf_map['times'] unchanged
    rf_map_pos: (n_pixels, 2) int array with the [x_pos, y_pos] of each pixel on screen,
        ordered x-major (y varies fastest)
    rf_stim_frames: dict with keys 'on' (white squares, brighter than gray) and 'off' (black
        squares, darker than gray). Each value is an (n_pixels, 1) object array whose entries
        are int arrays of the frame indices at which a stimulus appeared at that pixel, i.e.
        frames that are non-gray while the preceding frame was gray.
    """
    rf_map_times = rf_map['times']
    rf_map_frames = rf_map['frames'].astype('float')

    # Background gray level: median value over all frames and pixels
    gray = np.median(rf_map_frames)

    x_bin = rf_map_frames.shape[1]
    y_bin = rf_map_frames.shape[2]
    n_pix = x_bin * y_bin

    # Object arrays: each entry holds a variable-length array of frame indices
    stim_on_frames = np.zeros((n_pix, 1), dtype=object)
    stim_off_frames = np.zeros((n_pix, 1), dtype=object)
    rf_map_pos = np.zeros((n_pix, 2), dtype=int)

    for i, (x_pos, y_pos) in enumerate(np.ndindex(x_bin, y_bin)):
        pixel = rf_map_frames[:, x_pos, y_pos]
        pixel_val = pixel - gray

        # Value of the preceding frame for every frame. The frame before the first one is
        # taken to be gray, so a square already shown on frame 0 counts as an onset. Indexing
        # frame -1 would otherwise silently wrap around to the last frame of the recording.
        prev_frame = np.r_[gray, pixel[:-1]]
        # Stimulus onsets: frames that are not gray while the preceding frame was gray
        stim_pos = np.flatnonzero((pixel_val != 0) & (prev_frame == gray))

        # On stimulus, white squares; off stimulus, black squares
        stim_on_frames[i, 0] = stim_pos[pixel_val[stim_pos] > 0]
        stim_off_frames[i, 0] = stim_pos[pixel_val[stim_pos] < 0]

        rf_map_pos[i, :] = [x_pos, y_pos]

    rf_stim_frames = {'on': stim_on_frames, 'off': stim_off_frames}

    return rf_map_times, rf_map_pos, rf_stim_frames

66 

67 

def get_rf_map_over_depth(rf_map_times, rf_map_pos, rf_stim_frames, spike_times, spike_depths,
                          t_bin=0.01, d_bin=80, pre_stim=0.05, post_stim=1.5, y_lim=[0, 3840],
                          x_lim=None):
    """
    Compute receptive field map for each stimulus onset binned across depth

    Parameters
    ----------
    rf_map_times: time of each receptive field map frame; indexed with the frame indices
        stored in rf_stim_frames
    rf_map_pos: (n_pixels, 2) array with the [x_pos, y_pos] of each pixel
    rf_stim_frames: dict of stimulus type -> (n_pixels, 1) object array; each entry is an array
        of frame indices of stimulus onsets at the corresponding pixel of rf_map_pos
        (as produced by get_on_off_times_and_positions)
    spike_times: array of spike times
    spike_depths: array of spike depths along probe
    t_bin: bin size along time dimension
    d_bin: bin size along depth dimension
    pre_stim: time period before rf map stim onset to epoch around
    post_stim: time period after rf map onset to epoch around
    y_lim: values to limit to in depth direction
    x_lim: values to limit in time direction

    Returns
    -------
    rfmap: dict with the same keys as rf_stim_frames ('on' 'off' stimuli). Each value has
        shape (depths, x_pos, y_pos, epoch_window) and holds the binned activity averaged over
        all onsets at that pixel, in a window from pre_stim before to post_stim after onset.
        Pixels without any usable onset are left at zero.
    depths: depths between which receptive field map has been computed
    """
    binned_array, times, depths = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                             ylim=y_lim, xlim=x_lim)

    n_depths = depths.shape[0]
    n_times = binned_array.shape[1]
    x_bin = len(np.unique(rf_map_pos[:, 0]))
    y_bin = len(np.unique(rf_map_pos[:, 1]))
    n_bins = int((pre_stim + post_stim) / t_bin)

    rf_map = {}

    for stim_type, stims in rf_stim_frames.items():
        # Pixels that never get a usable stimulus keep this all-zero response
        _rf_map = np.zeros(shape=(n_depths, x_bin, y_bin, n_bins))
        for (x_pos, y_pos), stim_frame in zip(rf_map_pos, stims):

            # Case where there is no stimulus at this position
            if len(stim_frame[0]) == 0:
                continue

            stim_on_times = rf_map_times[stim_frame[0]]
            win_start = stim_on_times - pre_stim
            win_end = stim_on_times + post_stim

            # Each window is a fixed run of n_bins columns starting at the first entry of
            # `times` at or after win_start. Deriving the end index from a second searchsorted
            # can give n_bins +/- 1 columns through float rounding and break the assignment.
            idx_start = np.searchsorted(times, win_start)
            # Drop windows that end after the binned data, do not fit in the binned array, or
            # start more than one bin before the first time bin (they cannot be aligned).
            valid = ((win_start > times[0] - t_bin)
                     & (win_end <= times[-1])
                     & (idx_start + n_bins <= n_times))
            idx_start = idx_start[valid]

            # Case when no spikes during the passive period
            if idx_start.size == 0:
                continue

            stim_trials = np.zeros((n_depths, n_bins, idx_start.size))
            for i, on in enumerate(idx_start):
                stim_trials[:, :, i] = binned_array[:, on:on + n_bins]
            _rf_map[:, x_pos, y_pos, :] = np.mean(stim_trials, axis=2)

        rf_map[stim_type] = _rf_map

    return rf_map, depths

135 

136 

def get_svd_map(rf_map):
    """
    Perform SVD on the spatiotemporal rf_map and return the first spatial components

    Parameters
    ----------
    rf_map: dict of stimulus type -> array of shape (depths, x_pos, y_pos, n_time_bins), as
        returned by get_rf_map_over_depth

    Returns
    -------
    rf_svd: dict of stimulus type ('on' 'off' stimuli) -> list with one entry per depth. Each
        entry has shape (x_pos, y_pos) and is the first spatial SVD component scaled by its
        singular value, sign-oriented so that the median of the first temporal component
        is non-negative.
    """
    rf_svd = {}
    for stim_type, stims in rf_map.items():
        svd_stim = []
        for dep in stims:
            x_pix, y_pix, n_bins = dep.shape
            # Rows are pixels (flattened in (x_pix, y_pix) order), columns are time bins
            sub_reshaped = np.reshape(dep, (x_pix * y_pix, n_bins))
            # Baseline: mean over all pixels of the first time bin
            bsl = np.mean(sub_reshaped[:, 0])

            # Only the leading components are used, so skip computing the full square U and V
            u, s, v = svd(sub_reshaped - bsl, full_matrices=False)
            # Singular vectors are only defined up to sign; fix it from the temporal component
            sign = -1 if np.median(v[0, :]) < 0 else 1
            # Undo the flattening with the same (x_pix, y_pix) order used above
            rfs = sign * np.reshape(u[:, 0], (x_pix, y_pix))
            rfs *= s[0]

            svd_stim.append(rfs)

        rf_svd[stim_type] = svd_stim

    return rf_svd

168 

169 

def get_stim_aligned_activity(stim_events, spike_times, spike_depths, z_score_flag=True, d_bin=20,
                              t_bin=0.01, pre_stim=0.4, post_stim=1, base_stim=1,
                              y_lim=[0, 3840], x_lim=None):
    """
    Compute trial-averaged activity aligned to stimulus onsets, binned across depth

    Parameters
    ----------
    stim_events: dict of different stim events. Each key contains an np.array of the times of
        stimulus onset; NaN entries are ignored
    spike_times: array of spike times
    spike_depths: array of spike depths along probe
    z_score_flag: whether to return values as z_score of firing rate
    d_bin: bin size along depth dimension
    t_bin: bin size along time dimension
    pre_stim: time period before stim onset to epoch around
    post_stim: time period after stim onset to epoch around
    base_stim: time before stim onset at which the baseline window starts. The baseline window
        runs from onset - base_stim to onset - pre_stim and is used for the z_score correction
    y_lim: values to limit to in depth direction
    x_lim: values to limit in time direction

    Returns
    -------
    stim_activity: dict with, for each stimulus type, an array of shape (depths, n_bins) holding
        the binned activity (as returned by bincount2D) averaged over trials, or its z_score
        relative to the baseline window when z_score_flag is True. Trials whose stimulus or
        baseline window does not fit inside the binned data are skipped.
    """
    binned_array, times, depths = bincount2D(spike_times, spike_depths, t_bin, d_bin,
                                             ylim=y_lim, xlim=x_lim)
    n_times = binned_array.shape[1]
    n_bins = int((pre_stim + post_stim) / t_bin)
    n_bins_base = int(np.ceil((base_stim - pre_stim) / t_bin))

    stim_activity = {}
    for stim_type, stim_times in stim_events.items():

        # Get rid of any nan values
        stim_times = stim_times[~np.isnan(stim_times)]
        stim_intervals = stim_times - pre_stim
        base_intervals = stim_times - base_stim

        idx_stim = np.searchsorted(times, stim_intervals, side='right')
        idx_base = np.searchsorted(times, base_intervals, side='right')

        # Keep only trials whose whole stimulus and baseline windows lie inside the binned
        # data. Checking the window start alone is not enough: a window running past the last
        # time bin is truncated by the slices below and cannot fill the fixed-size arrays.
        keep = ((stim_intervals <= times[-1])
                & (idx_stim + n_bins <= n_times)
                & (idx_base + n_bins_base <= n_times))
        idx_stim = idx_stim[keep]
        idx_base = idx_base[keep]
        idx_stim = np.c_[idx_stim, idx_stim + n_bins]
        idx_base = np.c_[idx_base, idx_base + n_bins_base]

        stim_trials = np.zeros((depths.shape[0], n_bins, idx_stim.shape[0]))
        noise_trials = np.zeros((depths.shape[0], n_bins_base, idx_stim.shape[0]))
        for i, (st, ba) in enumerate(zip(idx_stim, idx_base)):
            stim_trials[:, :, i] = binned_array[:, st[0]:st[1]]
            noise_trials[:, :, i] = binned_array[:, ba[0]:ba[1]]

        # Average across trials
        avg_stim_trials = np.mean(stim_trials, axis=2)
        if z_score_flag:
            # Average the baseline across trials, then take its mean and std over time
            avg_noise_trials = np.mean(noise_trials, axis=2)
            avg_base_trials = np.mean(avg_noise_trials, axis=1)[:, np.newaxis]
            std_base_trials = np.std(avg_noise_trials, axis=1)[:, np.newaxis]
            # Depth bins with zero baseline std divide by zero: 0/0 gives NaN, which is set to
            # 0 below; a non-zero deviation gives +/-inf, which is left unchanged.
            with np.errstate(divide='ignore', invalid='ignore'):
                z_score = (avg_stim_trials - avg_base_trials) / std_base_trials
            z_score[np.isnan(z_score)] = 0
            avg_stim_trials = z_score

        stim_activity[stim_type] = avg_stim_trials

    return stim_activity