Coverage for brainbox/spike_features.py: 0%

25 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1''' 

2Functions that compute spike features from spike waveforms. 

3''' 

4 

5import numpy as np 

6from brainbox.io.spikeglx import extract_waveforms 

7 

8 

def depth(ephys_file, spks_b, clstrs_b, chnls_b, tmplts_b, unit, n_ch=12, n_ch_probe=385, sr=30000,
          dtype='int16', car=False):
    '''
    Gets `n_ch` channels around a unit's channel of max amplitude, extracts all unit spike
    waveforms from binary datafile for these channels, and for each spike, computes the dot
    products of waveform by unit template for those channels, and computes center-of-mass of these
    dot products to get spike depth estimates.

    Parameters
    ----------
    ephys_file : string
        The file path to the binary ephys data.
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. cluster IDs, times, features,
        etc.) for all spikes. The 'clusters' and 'times' fields are used.
    clstrs_b : bunch
        A clusters bunch containing fields with cluster information (e.g. amp, ch of max amp, depth
        of ch of max amp, etc.) for all clusters. The 'channels' field (channel of max amplitude,
        indexed by unit) is used.
    chnls_b : bunch
        A channels bunch containing fields with channel information (e.g. coordinates, indices,
        etc.) for all probe channels. Column 1 of the 'localCoordinates' field is used as the
        channel depths.
    tmplts_b : bunch
        A unit templates bunch containing fields with unit template information (e.g. template
        waveforms, etc.) for all unit templates. The 'waveforms' field, indexed as
        (unit, sample, channel), is used.
    unit : numeric
        The unit for which to return the spikes depths.
    n_ch : int (optional)
        The number of channels to sample around the channel of max amplitude to compute the
        depths. The window is centred on the channel of max amplitude and is shifted inward
        (not truncated) when it would run past either end of the available channels.
    n_ch_probe : int (optional)
        The number of channels of the recording.
    sr : int (optional)
        The sampling rate (hz) that the ephys data was acquired at.
    dtype : str (optional)
        The datatype represented by the bytes in `ephys_file`.
    car : bool (optional)
        A flag to perform common-average-referencing before extracting waveforms.

    Returns
    -------
    d : ndarray
        The estimated spike depths (float) for all spikes in `unit`, in the same units as
        `chnls_b['localCoordinates']`. An empty array if `unit` has no spikes.

    See Also
    --------
    brainbox.io.spikeglx.extract_waveforms

    Examples
    --------
    1) Get the spike depths for unit 1.
        >>> import numpy as np
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get the necessary alf objects from an alf directory.
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> clstrs_b = aio.load_object(path_to_alf_out, 'clusters')
        >>> chnls_b = aio.load_object(path_to_alf_out, 'channels')
        >>> tmplts_b = aio.load_object(path_to_alf_out, 'templates')
        # Compute spike depths.
        >>> unit1_depths = bb.spike_features.depth(path_to_ephys_file, spks_b, clstrs_b, chnls_b,
        ...                                         tmplts_b, unit=1)
    '''

    # Get the unit's spike timestamps.
    ts = spks_b['times'][spks_b['clusters'] == unit]
    if ts.size == 0:
        # No spikes for this unit: nothing to read from `ephys_file`.
        return np.zeros(0)

    tmplt_wfs = tmplts_b['waveforms']
    ch_coords = chnls_b['localCoordinates']

    # Choose `n_ch` consecutive channels centred on the channel of max amplitude. Near either end
    # the window is shifted inward rather than truncated, so it always contains `max_ch` and keeps
    # its full width. Every channel must be a valid index into the recording, the templates
    # (axis 2) and the channel coordinates (axis 0), so the window is bounded by all three.
    max_ch = int(clstrs_b['channels'][unit])
    n_avail = min(n_ch_probe, tmplt_wfs.shape[2], ch_coords.shape[0])
    n_win = min(n_ch, n_avail)
    first_ch = min(max(max_ch - n_ch // 2, 0), n_avail - n_win)
    ch = np.arange(first_ch, first_ch + n_win)

    # Unit template restricted to `ch`, shape (n_samples, n_win).
    unit_tmplt = tmplt_wfs[unit][:, ch]
    wf_t = tmplt_wfs.shape[1] / (sr / 1000)  # duration (ms) of each waveform
    wf = extract_waveforms(ephys_file=ephys_file, ts=ts, ch=ch, t=wf_t, sr=sr,
                           n_ch_probe=n_ch_probe, dtype=dtype, car=car)

    # Compute center-of-mass: #
    ch_depths = ch_coords[ch, 1]
    # Dot product of each spike's waveform with the template, per channel: (n_spikes, n_win).
    # `wf` is indexed (spike, sample, channel), matching the template's (sample, channel) layout.
    dots = np.einsum('ijk,jk->ik', wf, unit_tmplt)
    # Make each spike's weights non-negative, then scale them to a max of 1.
    # NOTE(review): adding |min| zeroes the minimum only when it is negative. When every dot
    # product is positive, it shifts the weights further up, which flattens the distribution.
    # It is kept as-is to preserve existing results; confirm whether `dots - min` was intended.
    dots = dots + np.abs(np.min(dots, axis=1, keepdims=True))
    dots = dots / np.max(dots, axis=1, keepdims=True)
    # Center of mass of the channel depths, weighted by the normalized dot products.
    return np.sum(dots * ch_depths, axis=1) / np.sum(dots, axis=1)