Skip to content

Commit d5127c0

Browse files
committed
fix: add explicit normalization
1 parent cdf01cf commit d5127c0

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

causallearn/utils/FastKCI/FastKCI.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ def partition_data(self):
9898
ll = np.tile(np.log(pi_j), (self.n, 1))
9999
for k in range(self.K):
100100
ll[:, k] += stats.multivariate_normal.logpdf(self.data_z, mu_k[k, :], cov=sigma_k, allow_singular=True)
101-
Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
101+
102+
ll = np.exp(ll - logsumexp(ll, axis=1, keepdims=True))
103+
ll = ll / ll.sum(axis=1, keepdims=True)
104+
105+
Z = np.array([np.random.multinomial(1, ll[n, :]).argmax() for n in range(self.n)])
102106
le = LabelEncoder()
103107
Z = le.fit_transform(Z)
104108
return Z
@@ -414,7 +418,11 @@ def partition_data(self):
414418
ll = np.tile(np.log(pi_j), (self.n, 1))
415419
for k in range(self.K):
416420
ll[:, k] += stats.multivariate_normal.logpdf(self.data_y, mu_k[k, :], cov=sigma_k, allow_singular=True)
417-
Z = np.array([np.random.multinomial(1, np.exp(ll[n, :]-logsumexp(ll[n, :]))).argmax() for n in range(self.n)])
421+
422+
ll = np.exp(ll - logsumexp(ll, axis=1, keepdims=True))
423+
ll = ll / ll.sum(axis=1, keepdims=True)
424+
425+
Z = np.array([np.random.multinomial(1, ll[n, :]).argmax() for n in range(self.n)])
418426
prop_Y = np.take_along_axis(ll, Z[:, None], axis=1).sum()
419427
le = LabelEncoder()
420428
Z = le.fit_transform(Z)

0 commit comments

Comments
 (0)