Skip to content

Commit 0a70476

Browse files
Merge pull request #17 from WFHong/WFHong
fix some bugs
2 parents a081b75 + 3271cca commit 0a70476

30 files changed

+1355
-1828
lines changed

cdmir/discovery/funtional_based/LearningHierarchicalStructure/Causal_Discovery_in_LHM.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
import Utils
99

1010

11-
1211
def Causal_Discovery_LHM(data, alpha=0.01):
13-
'''
12+
"""
1413
Function: Causal Discovery in Linear Latent Hierarchical Structure
1514
Parameter
1615
data: DataFrame (pandas)
@@ -23,82 +22,72 @@ def Causal_Discovery_LHM(data, alpha=0.01):
2322
Graph (selected)
2423
the Causal graph of hierarchical structure
2524
26-
'''
27-
#Initize Variable
28-
Current_Clusters=[]
29-
PureClusters=[]
30-
GeneralPureClusters=[]
31-
ImpureClusters=[]
25+
"""
26+
# Initialize variables
27+
Current_Clusters = []
28+
PureClusters = []
29+
GeneralPureClusters = []
30+
ImpureClusters = []
3231

33-
AllCausalCluster=[]
32+
AllCausalCluster = []
3433

35-
LatentIndex={}
34+
LatentIndex = {}
3635
LatentNum = 1
3736
Ora_data = data.copy()
3837

39-
4038
while True:
4139

4240
'''Begin recursive procedure'''
4341

44-
#Phase I: Finding causal cluster and judge the purity or impurity
45-
Current_Clusters, PureClusters, ImpureClusters, PClusters=FindCausalCluster.FindCausalCluster(data, PureClusters, ImpureClusters, alpha)
42+
# Phase I: Find the causal clusters and judge whether each one is pure or impure
43+
Current_Clusters, PureClusters, ImpureClusters, PClusters = FindCausalCluster.FindCausalCluster(data,
44+
PureClusters,
45+
ImpureClusters,
46+
alpha)
4647
AllCausalCluster = Utils.ExtendList(AllCausalCluster, PClusters)
4748
AllCausalCluster = Utils.ExtendList(AllCausalCluster, Current_Clusters)
4849

49-
#debug
50+
# debug
5051
print('Finished Finding Causal Cluster: ', Current_Clusters, PureClusters, ImpureClusters, PClusters)
5152

53+
# Phase II: Check the merge rules for the learned clusters and update the record variables
5254

53-
#Phase II: Check merge rule for the learned clusters and update record variables
54-
55-
Merge_Results, PureClusters, ImpureClusters, AllCausalCluster, GeneralPureClusters, LatentIndex = MC.MergeCausalCluster(Current_Clusters, PureClusters, ImpureClusters, AllCausalCluster, GeneralPureClusters, LatentIndex, data, Ora_data, alpha)
55+
Merge_Results, PureClusters, ImpureClusters, AllCausalCluster, GeneralPureClusters, LatentIndex = MC.MergeCausalCluster(
56+
Current_Clusters, PureClusters, ImpureClusters, AllCausalCluster, GeneralPureClusters, LatentIndex, data,
57+
Ora_data, alpha)
5658

5759
MergeCluster = Merge_Results[0]
5860
EarlyLearningImpureClusters = Merge_Results[1]
5961
EarlyLearningRemoveClusters = Merge_Results[2]
6062
IntroduceLatent_PureClusters = Merge_Results[3]
6163
RemainingVariables = Merge_Results[4]
6264

63-
print('Merge_Results: ',Merge_Results)
65+
print('Merge_Results: ', Merge_Results)
6466
print(LatentIndex)
6567

66-
67-
6868
if len(MergeCluster) == 0 and len(IntroduceLatent_PureClusters) == 0:
6969
print('This is nothing be learned !')
70-
if len(RemainingVariables) == 0 and len(EarlyLearningImpureClusters) !=0:
70+
if len(RemainingVariables) == 0 and len(EarlyLearningImpureClusters) != 0:
7171
print('There are something wrong! In the merger Results !!', EarlyLearningImpureClusters)
7272
exit(-1)
7373
elif len(RemainingVariables) <= 3:
7474
print('Recursive Procedure Finished ! The structure is identified up to a Markov equivalent class.')
7575
print(LatentIndex, ImpureClusters)
7676
break
7777

78-
79-
80-
#Phase III: Introduce latent variable into the graph and update the actived data set
78+
# Phase III: Introduce latent variables into the graph and update the active data set
8179
data, LatentNum, LatentIndex = UpdataData.UpdataData(Merge_Results, LatentNum, LatentIndex, data)
8280

83-
if len(data) <=1:
81+
if len(data) <= 1:
8482
print(LatentIndex, ImpureClusters)
8583
break
8684

8785
print(data, LatentNum, LatentIndex, ImpureClusters)
8886

8987
'''End recursive procedure'''
9088

91-
MakeGraph.Make_graph_Impure(LatentIndex,ImpureClusters)
89+
MakeGraph.Make_graph_Impure(LatentIndex, ImpureClusters)
9290

93-
#Phase IV: orientation causal direction among latent variable, including latent measured
91+
# Phase IV: Orient the causal directions among the latent variables, including between latent and measured variables
9492
Orientation.Orientation_Cluster(Ora_data, LatentIndex, PureClusters, AllCausalCluster)
95-
#ImpureOrder = Orientation.Orientation_ImpureCluster(Ora_data, LatentIndex, PureClusters, AllCausalCluster, ImpureClusters)
96-
97-
98-
99-
100-
101-
102-
103-
104-
93+
# ImpureOrder = Orientation.Orientation_ImpureCluster(Ora_data, LatentIndex, PureClusters, AllCausalCluster, ImpureClusters)

cdmir/discovery/funtional_based/LearningHierarchicalStructure/FindCausalCluster.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import TetradMethod
88
import VanishedTest as VT
99

10+
1011
def FindCausalCluster(data, PureCluster, ImpureCluster, alpha=0.01):
11-
'''
12+
"""
1213
Function: Find causal clusters in the active data set and check their purity
1314
Parameter
1415
data: DataFrame
@@ -29,55 +30,48 @@ def FindCausalCluster(data, PureCluster, ImpureCluster, alpha=0.01):
2930
Update ImpureCluster set by adding the current learned clusters
3031
PClusters: List
3132
Return only the two-element clusters
32-
'''
33-
LearnedClusters=[]
34-
PClusters=[]
35-
indexs=list(data.columns) #all observed data in current procedure
36-
B=indexs.copy() #remain variables
37-
ClusterLength=2
33+
"""
34+
LearnedClusters = []
35+
PClusters = []
36+
indexs = list(data.columns) # all observed data in current procedure
37+
B = indexs.copy() # remain variables
38+
ClusterLength = 2
3839
for S in itertools.combinations(list(B), ClusterLength):
3940
if TetradMethod.CheckCausalCluster(list(S), data, alpha):
4041
LearnedClusters.append(S)
4142
PClusters.append(S)
42-
B=set(B)-set(S)
43-
44-
#overlap merge, updata the causal cluster and add into PureCluster
45-
#only recall the overlap merge function and check whether the cluster with more than three elements
46-
LearnedClusters=Overlap_Merge.merge_list(LearnedClusters)
43+
B = set(B) - set(S)
4744

45+
# Overlap merge: update the causal clusters and add them to PureCluster
46+
# Only call the overlap-merge function, then check whether each cluster has more than two elements
47+
LearnedClusters = Overlap_Merge.merge_list(LearnedClusters)
4848

4949
for S in LearnedClusters:
50-
if len(S) > 2: #Overlap merged cluster add into Purecluster
50+
if len(S) > 2: # Overlap merged cluster add into Purecluster
5151
PureCluster.append(S)
52-
else: # run the identifying pure cluster function
53-
Pure_flag=TetradMethod.JudgePureCluster(S, data, alpha)
52+
else: # run the identifying pure cluster function
53+
Pure_flag = TetradMethod.JudgePureCluster(S, data, alpha)
5454
if Pure_flag:
5555
PureCluster.append(S)
5656
else:
5757
ImpureCluster.append(S)
5858

59-
ClusterLength +=1
59+
ClusterLength += 1
6060
while len(B) >= ClusterLength and len(indexs) > (ClusterLength + 2):
6161
for S in itertools.combinations(list(B), ClusterLength):
62-
S=list(S)
62+
S = list(S)
6363
if TetradMethod.CheckCausalCluster(list(S), data, alpha):
6464
LearnedClusters.append(S)
6565
ImpureCluster.append(S)
66-
B=set(B)-set(S)
67-
ClusterLength +=1
66+
B = set(B) - set(S)
67+
ClusterLength += 1
6868

6969
return LearnedClusters, PureCluster, ImpureCluster, PClusters
7070

7171

72-
73-
74-
75-
76-
77-
78-
7972
def main():
8073
pass
8174

75+
8276
if __name__ == '__main__':
8377
main()

cdmir/discovery/funtional_based/LearningHierarchicalStructure/GIN2.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,56 @@
11
import numpy as np
22
import pandas as pd
3-
import cdmir.discovery.funtional_based.LearningHierarchicalStructure.indTest.HSIC2 as hsic
3+
import lingam.hsic as hsic
44

55
import cdmir.discovery.funtional_based.LearningHierarchicalStructure.indTest.independence as ID
66

77

8-
#GIN by fast HSIC
9-
#X=['X1','X2']
10-
#Z=['X3']
11-
#data.type=Pandas.DataFrame
12-
def GIN(X,Z,data,alpha=0.05):
13-
omega = getomega(data,X,Z)
14-
tdata= data[X]
15-
#print(tdata.T)
8+
# GIN by fast HSIC
9+
# X=['X1','X2']
10+
# Z=['X3']
11+
# data.type=Pandas.DataFrame
12+
def GIN(X, Z, data, alpha=0.05):
13+
omega = getomega(data, X, Z)
14+
tdata = data[X]
1615
result = np.dot(omega, tdata.T)
1716
for i in Z:
1817
temp = np.array(data[i])
19-
flag =hsic.test(result.T,temp,alpha)
2018

21-
if not flag:#not false == ture ---> if false
22-
#print(X,Z,flag)
19+
pvalue = hsic.hsic_test_gamma(result.T, temp, alpha)
20+
21+
if pvalue > alpha:
22+
flag = True
23+
else:
24+
flag = False
25+
26+
if not flag:
2327
return False
2428

2529
return True
2630

2731

28-
#mthod 1: estimating mutual information by k nearest neighbors (density estimation)
29-
#mthod 2: estimating mutual information by sklearn package
30-
def GIN_MI(X,Z,data,method='1'):
31-
omega = getomega(data,X,Z)
32-
tdata= data[X]
32+
# Method 1: estimate mutual information by k nearest neighbors (density estimation)
33+
# Method 2: estimate mutual information with the sklearn package
34+
def GIN_MI(X, Z, data, method='1'):
35+
omega = getomega(data, X, Z)
36+
tdata = data[X]
3337
result = np.dot(omega, tdata.T)
34-
MIS=0
38+
MIS = 0
3539
for i in Z:
3640

3741
temp = np.array(data[i])
38-
if method =='1':
39-
mi=ID.independent(result.T,temp)
42+
if method == '1':
43+
mi = ID.independent(result.T, temp)
4044
else:
41-
mi=ID.independent11(result.T,temp)
42-
MIS+=mi
43-
MIS = MIS/len(Z)
45+
mi = ID.independent11(result.T, temp)
46+
MIS += mi
47+
MIS = MIS / len(Z)
4448

4549
return MIS
4650

4751

48-
49-
def getomega(data,X,Z):
50-
cov_m =np.cov(data,rowvar=False)
52+
def getomega(data, X, Z):
53+
cov_m = np.cov(data, rowvar=False)
5154
col = list(data.columns)
5255
Xlist = []
5356
Zlist = []
@@ -58,13 +61,12 @@ def getomega(data,X,Z):
5861
t = col.index(i)
5962
Zlist.append(t)
6063
B = cov_m[Xlist]
61-
B = B[:,Zlist]
64+
B = B[:, Zlist]
6265
A = B.T
63-
u,s,v = np.linalg.svd(A)
66+
u, s, v = np.linalg.svd(A)
6467
lens = len(X)
65-
omega =v.T[:,lens-1]
66-
omegalen=len(omega)
67-
omega=omega.reshape(1,omegalen)
68+
omega = v.T[:, lens - 1]
69+
omegalen = len(omega)
70+
omega = omega.reshape(1, omegalen)
6871

6972
return omega
70-

0 commit comments

Comments
 (0)