Coverage for brainbox/quality/permutation_test.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-07-08 17:16 +0100

1""" 

2Quality control for arbitrary metrics, using permutation testing. 

3 

4Written by Sebastian Bruijns 

5""" 

6 

7import numpy as np 

8import time 

9import matplotlib.pyplot as plt 

10# TODO: take in eids and download data yourself? 

11 

12 

def permut_test(data1, data2, metric, n_permut=1000, show=False, title=None):
    """
    Compute the probability of observing the metric difference between two datasets,
    via permutation testing.

    We're taking absolute values of differences, because the order of dataset input
    shouldn't matter.
    We're only comparing means of the per-entity metrics; pass a different `metric` if a
    more complicated reduction is needed.
    Pay attention to always give one list (even if it's just one dataset, but then it
    doesn't make sense anyway...)

    Parameters
    ----------
    data1 : array-like
        First data set, list or array of data-entities to use for permutation test
        (make data2 optional and then permutation test more similar to tuning sensitivity?)
    data2 : array-like
        Second data set, also list or array of data-entities to use for permutation test
    metric : function, array-like -> float
        Metric to use for permutation test, will be used to reduce elements of data1 and
        data2 to one number
    n_permut : integer (optional)
        Number of permutations to use for the test
    show : Boolean (optional)
        Whether or not to show a plot of the permutation distribution and a marker for the
        position of the true difference in relation to this distribution
    title : string (optional)
        Title of the plot; if given, the plot is also saved to ``<title>.png``

    Returns
    -------
    p : float
        p-value of true difference in permutation distribution

    See Also
    --------
    plot_permut_test

    Examples
    --------
    TODO:
    """
    # Reduce every data-entity to one number, then take the absolute difference of the
    # group means as the observed test statistic.
    metrics1 = [metric(d) for d in data1]
    metrics2 = [metric(d) for d in data2]
    true_diff = np.abs(np.mean(metrics1) - np.mean(metrics2))

    # Prepare permutations: pool all metric values; size1 remembers where to split.
    size1 = len(metrics1)
    diffs = np.concatenate((metrics1, metrics2))
    permutations = np.zeros((n_permut, diffs.size), dtype=np.int32)

    # Create permutations, could be parallelized or vectorized in principle, but unclear how
    indices = np.arange(diffs.size)
    for i in range(n_permut):
        np.random.shuffle(indices)
        permutations[i] = indices

    # For every permutation, split the pooled values into two groups of the original sizes
    # and recompute the statistic; p is the fraction of permuted statistics that strictly
    # exceed the observed one.
    permut_diffs = np.abs(np.mean(diffs[permutations[:, :size1]], axis=1) -
                          np.mean(diffs[permutations[:, size1:]], axis=1))
    p = len(permut_diffs[permut_diffs > true_diff]) / n_permut

    if show or title:
        plot_permut_test(permut_diffs=permut_diffs, true_diff=true_diff, p=p, title=title)

    return p

81 

82 

def plot_permut_test(permut_diffs, true_diff, p, title=None):
    """Plot permutation test result.

    Parameters
    ----------
    permut_diffs : array-like
        Metric differences obtained from the permuted group assignments
    true_diff : float
        Observed difference between the two original groups, marked with a red star
    p : float
        p-value, displayed in the plot title
    title : string (optional)
        If given, the figure is saved to ``<title>.png``
    """
    # Histogram of the permutation distribution; star marks the observed difference,
    # placed at 1/20 of the tallest bin so it sits just above the x-axis.
    n, _, _ = plt.hist(permut_diffs)
    plt.plot(true_diff, np.max(n) / 20, '*r', markersize=12)

    # Prettify plot
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.title("p = {}".format(p))

    if title:
        plt.savefig(title + '.png')
    # NOTE(review): the figure is closed, not shown — there is no plt.show() here; the
    # source rendering loses indentation, so confirm plt.close() is meant to run
    # unconditionally rather than only in the `if title:` branch.
    plt.close()

96 

97 

if __name__ == '__main__':
    # Smoke test: two Gaussian groups whose means differ by 0.1.
    rng = np.random.RandomState(2)
    data1 = rng.normal(0, 1, (23, 5))
    data2 = rng.normal(0.1, 1, (32, 5))
    t = time.time()
    # Bug fix: the keyword argument is `show`, not `plot` — permut_test has no `plot`
    # parameter, so the original call raised TypeError.
    p = permut_test(data1, data2, np.mean, show=True)
    print(time.time() - t)
    print(p)