Coverage for brainbox/quality/permutation_test.py: 0%
40 statements
« prev ^ index » next coverage.py v7.5.4, created at 2024-07-08 17:16 +0100
1"""
2Quality control for arbitrary metrics, using permutation testing.
4Written by Sebastian Bruijns
5"""
7import numpy as np
8import time
9import matplotlib.pyplot as plt
10# TODO: take in eids and download data yourself?
def permut_test(data1, data2, metric, n_permut=1000, show=False, title=None):
    """
    Compute the probability of observing the metric difference between datasets, via permutation testing.

    We're taking absolute values of differences, because the order of dataset input shouldn't
    matter.
    We're only computing means, what if we want to apply a more complicated function to the
    permutation result?
    Pay attention to always give one list (even if its just one dataset, but then it doesn't make
    sense anyway...)

    Parameters
    ----------
    data1 : array-like
        First data set, list or array of data-entities to use for permutation test
        (make data2 optional and then permutation test more similar to tuning sensitivity?)
    data2 : array-like
        Second data set, also list or array of data-entities to use for permutation test
    metric : function, array-like -> float
        Metric to use for permutation test, will be used to reduce elements of data1 and data2
        to one number
    n_permut : integer (optional)
        Number of permutations to use for test
    show : Boolean (optional)
        Whether or not to show a plot of the permutation distribution and a marker for the position
        of the true difference in relation to this distribution
    title : string (optional)
        If given, the permutation plot is produced (even when show is False) and saved
        under this name

    Returns
    -------
    p : float
        p-value of true difference in permutation distribution

    See Also
    --------
    TODO:

    Examples
    --------
    TODO:
    """
    # Reduce every data entity to one number, then compute the observed
    # absolute difference of group means.
    metrics1 = [metric(d) for d in data1]
    metrics2 = [metric(d) for d in data2]
    true_diff = np.abs(np.mean(metrics1) - np.mean(metrics2))

    # Prepare permutations: pool both groups; each permutation reassigns
    # pool members to a group of size1 and the remainder.
    size1 = len(metrics1)
    diffs = np.concatenate((metrics1, metrics2))
    permutations = np.zeros((n_permut, diffs.size), dtype=np.int32)

    # Create permutations, could be parallelized or vectorized in principle, but unclear how
    indizes = np.arange(diffs.size)
    for i in range(n_permut):
        np.random.shuffle(indizes)
        permutations[i] = indizes

    # Null distribution of absolute mean differences under random group assignment.
    permut_diffs = np.abs(np.mean(diffs[permutations[:, :size1]], axis=1) -
                          np.mean(diffs[permutations[:, size1:]], axis=1))
    # One-sided (strictly greater) empirical p-value.
    p = len(permut_diffs[permut_diffs > true_diff]) / n_permut

    if show or title:
        plot_permut_test(permut_diffs=permut_diffs, true_diff=true_diff, p=p, title=title)
    return p
def plot_permut_test(permut_diffs, true_diff, p, title=None):
    """Plot permutation test result.

    Parameters
    ----------
    permut_diffs : array-like
        Absolute metric differences obtained from the permutations
        (the null distribution)
    true_diff : float
        Observed absolute metric difference between the two datasets
    p : float
        p-value of the observed difference, displayed in the plot title
    title : string (optional)
        If given, the figure is saved as '<title>.png' and closed;
        otherwise it is displayed interactively
    """
    counts, _, _ = plt.hist(permut_diffs)
    # Mark the true difference just above the x-axis, scaled to the histogram height.
    plt.plot(true_diff, np.max(counts) / 20, '*r', markersize=12)

    # Prettify plot
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.title("p = {}".format(p))

    if title:
        plt.savefig(title + '.png')
        plt.close()
    else:
        # Bug fix: without a title the figure was neither shown nor closed,
        # so show=True in permut_test had no visible effect in a script.
        plt.show()
if __name__ == '__main__':
    # Quick self-check: two Gaussian groups with slightly different means.
    rng = np.random.RandomState(2)
    data1 = rng.normal(0, 1, (23, 5))
    data2 = rng.normal(0.1, 1, (32, 5))
    t = time.time()
    # Bug fix: the keyword is 'show', not 'plot' — the old call raised TypeError.
    p = permut_test(data1, data2, np.mean, show=True)
    print(time.time() - t)
    print(p)