
Commit 165ed2d

Merge branch 'py-why:main' into main
2 parents: 509ebe7 + 9dc0365

File tree

19 files changed: +1456, -33 lines


causallearn/score/LocalScoreFunctionClass.py

Lines changed: 4 additions & 4 deletions
@@ -29,7 +29,7 @@ def __init__(
         self.parameters = parameters
         self.score_cache = {}

-        if self.local_score_fun == local_score_BIC_from_cov:
+        if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
             self.cov = np.cov(self.data.T)
             self.n = self.data.shape[0]

@@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
         hash_key = tuple(sorted(PAi))

         if not self.score_cache[i].__contains__(hash_key):
-            if self.local_score_fun == local_score_BIC_from_cov:
+            if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
                 self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
             else:
                 self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters)

         return self.score_cache[i][hash_key]

     def score_nocache(self, i: int, PAi: List[int]) -> float:
-        if self.local_score_fun == local_score_BIC_from_cov:
+        if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
             return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
         else:
-            return self.local_score_fun(self.data, i, PAi, self.parameters)
+            return self.local_score_fun(self.data, i, PAi, self.parameters)
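Note: where a removed and an added line above look identical, the difference is whitespace only (this recurs in the diffs below). The substantive change replaces an identity comparison on the function object with a comparison of its __name__, which still matches when the scorer has been wrapped or re-imported and `==` on the function object would fail. A minimal sketch of that failure mode (the `wrapped` decorator below is hypothetical, for illustration only):

import functools

def local_score_BIC_from_cov(data, i, PAi, parameters):
    ...  # stand-in for the real scorer

@functools.wraps(local_score_BIC_from_cov)
def wrapped(*args, **kwargs):
    # e.g. a logging or caching wrapper added by user code
    return local_score_BIC_from_cov(*args, **kwargs)

print(wrapped == local_score_BIC_from_cov)             # False: different objects
print(wrapped.__name__ == 'local_score_BIC_from_cov')  # True: __name__ survives wraps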

causallearn/search/PermutationBased/BOSS.py

Lines changed: 10 additions & 10 deletions
@@ -53,7 +53,7 @@ def boss(
     if n < p:
         warnings.warn("The number of features is much larger than the sample size!")

-    if score_func == "local_score_CV_general":
+    if score_func == "local_score_CV_general":
         # % k-fold negative cross validated likelihood based on regression in RKHS
         if parameters is None:
             parameters = {
@@ -63,13 +63,13 @@ def boss(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_cv_general, parameters=parameters
         )
-    elif score_func == "local_score_marginal_general":
+    elif score_func == "local_score_marginal_general":
         # negative marginal likelihood based on regression in RKHS
         parameters = {}
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_marginal_general, parameters=parameters
         )
-    elif score_func == "local_score_CV_multi":
+    elif score_func == "local_score_CV_multi":
         # k-fold negative cross validated likelihood based on regression in RKHS
         # for data with multi-variate dimensions
         if parameters is None:
@@ -83,7 +83,7 @@ def boss(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_cv_multi, parameters=parameters
         )
-    elif score_func == "local_score_marginal_multi":
+    elif score_func == "local_score_marginal_multi":
         # negative marginal likelihood based on regression in RKHS
         # for data with multi-variate dimensions
         if parameters is None:
@@ -93,22 +93,22 @@ def boss(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
         )
-    elif score_func == "local_score_BIC":
+    elif score_func == "local_score_BIC":
         # SEM BIC score
-        warnings.warn("Please use 'local_score_BIC_from_cov' instead")
+        warnings.warn("Using 'local_score_BIC_from_cov' instead for efficiency")
         if parameters is None:
             parameters = {"lambda_value": 2}
         localScoreClass = LocalScoreClass(
-            data=X, local_score_fun=local_score_BIC, parameters=parameters
+            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
         )
-    elif score_func == "local_score_BIC_from_cov":
+    elif score_func == "local_score_BIC_from_cov":
         # SEM BIC score
         if parameters is None:
             parameters = {"lambda_value": 2}
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
         )
-    elif score_func == "local_score_BDeu":
+    elif score_func == "local_score_BDeu":
         # BDeu score
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,4 +204,4 @@ def better_mutation(v, order, gsts):
     order.remove(v)
     order.insert(best - int(best > i), v)

-    return True
+    return True
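For context, a minimal usage sketch of the dispatch changed above (synthetic data; `boss` is assumed here to return a GeneralGraph, consistent with the library's other permutation-based searches):

import numpy as np
from causallearn.search.PermutationBased.BOSS import boss

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 5))
X[:, 1] += 0.8 * X[:, 0]  # induce one linear dependency

# After this commit, "local_score_BIC" warns and then transparently uses the
# covariance-based scorer, so both calls below run local_score_BIC_from_cov.
G1 = boss(X, score_func="local_score_BIC")
G2 = boss(X, score_func="local_score_BIC_from_cov")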

causallearn/search/PermutationBased/GRaSP.py

Lines changed: 10 additions & 10 deletions
@@ -16,7 +16,7 @@
     local_score_marginal_general,
     local_score_marginal_multi,
 )
-from causallearn.search.PermutationBased.gst import GST;
+from causallearn.search.PermutationBased.gst import GST
 from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
 from causallearn.utils.DAG2CPDAG import dag2cpdag

@@ -111,7 +111,7 @@ def grasp(
     if n < p:
         warnings.warn("The number of features is much larger than the sample size!")

-    if score_func == "local_score_CV_general":
+    if score_func == "local_score_CV_general":
         # k-fold negative cross validated likelihood based on regression in RKHS
         if parameters is None:
             parameters = {
@@ -127,7 +127,7 @@ def grasp(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_marginal_general, parameters=parameters
         )
-    elif score_func == "local_score_CV_multi":
+    elif score_func == "local_score_CV_multi":
         # k-fold negative cross validated likelihood based on regression in RKHS
         # for data with multi-variate dimensions
         if parameters is None:
@@ -141,7 +141,7 @@ def grasp(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_cv_multi, parameters=parameters
         )
-    elif score_func == "local_score_marginal_multi":
+    elif score_func == "local_score_marginal_multi":
         # negative marginal likelihood based on regression in RKHS
         # for data with multi-variate dimensions
         if parameters is None:
@@ -151,22 +151,22 @@ def grasp(
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
         )
-    elif score_func == "local_score_BIC":
+    elif score_func == "local_score_BIC":
         # SEM BIC score
-        warnings.warn("Please use 'local_score_BIC_from_cov' instead")
+        warnings.warn("Using 'local_score_BIC_from_cov' instead for efficiency")
         if parameters is None:
             parameters = {"lambda_value": 2}
         localScoreClass = LocalScoreClass(
-            data=X, local_score_fun=local_score_BIC, parameters=parameters
+            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
         )
-    elif score_func == "local_score_BIC_from_cov":
+    elif score_func == "local_score_BIC_from_cov":
         # SEM BIC score
         if parameters is None:
             parameters = {"lambda_value": 2}
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
         )
-    elif score_func == "local_score_BDeu":
+    elif score_func == "local_score_BDeu":
         # BDeu score
         localScoreClass = LocalScoreClass(
             data=X, local_score_fun=local_score_BDeu, parameters=None
@@ -204,7 +204,7 @@ def grasp(
         sys.stdout.flush()

     runtime = time.perf_counter() - runtime
-
+
     if verbose:
         sys.stdout.write("\nGRaSP completed in: %.2fs \n" % runtime)
         sys.stdout.flush()
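The same redirection now happens in GRaSP. A sketch of how to confirm the new warning text, assuming `grasp` accepts the arguments shown in the hunk headers:

import warnings
import numpy as np
from causallearn.search.PermutationBased.GRaSP import grasp

X = np.random.default_rng(1).standard_normal((300, 4))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    G = grasp(X, score_func="local_score_BIC")

# The warning should now say the covariance-based scorer is being used.
assert any("local_score_BIC_from_cov" in str(w.message) for w in caught)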

causallearn/search/ScoreBased/CALM.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@ (new file)
import numpy as np
import torch
import torch.nn as nn
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network
from causallearn.utils.CALMUtils import *
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from typing import Any, Dict
from scipy.special import expit as sigmoid

torch.set_default_dtype(torch.double)

def calm(
    X: np.ndarray,
    lambda1: float = 0.005,
    alpha: float = 0.01,
    tau: float = 0.5,
    rho_init: float = 1e-5,
    rho_mult: float = 3,
    htol: float = 1e-8,
    subproblem_iter: int = 40000,
    standardize: bool = False,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """
    Perform the CALM (Continuous and Acyclicity-constrained L0-penalized likelihood with estimated Moral graph) algorithm.

    Parameters
    ----------
    X : numpy.ndarray
        Input dataset of shape (n, d), where n is the number of samples
        and d is the number of variables.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty, which encourages sparsity in the learned graph. Default is 0.005.
    alpha : float, optional
        Significance level for conditional independence tests. Default is 0.01.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid. Default is 0.5.
    rho_init : float, optional
        Initial value of the penalty parameter for the acyclicity constraint. Default is 1e-5.
    rho_mult : float, optional
        Multiplication factor for rho in each iteration. Default is 3.
    htol : float, optional
        Tolerance level for the acyclicity constraint. Default is 1e-8.
    subproblem_iter : int, optional
        Number of iterations for subproblem optimization. Default is 40000.
    standardize : bool, optional
        Whether to standardize the input data (mean=0, variance=1). Default is False.
    device : str, optional
        The device to use for computation ('cpu' or 'cuda'). Default is 'cpu'.

    Returns
    -------
    Record : dict
        A dictionary containing:
        - Record['G']: the learned causal graph, a DAG, where Record['G'].graph[j,i]=1 and Record['G'].graph[i,j]=-1 indicates i --> j.
        - Record['B_weighted']: the weighted adjacency matrix of the learned causal graph.
    """

    d = X.shape[1]
    if standardize:
        mean_X = np.mean(X, axis=0, keepdims=True)
        std_X = np.std(X, axis=0, keepdims=True)
        X = (X - mean_X) / std_X
    else:
        X = X - np.mean(X, axis=0, keepdims=True)

    # Compute the data covariance matrix
    cov_emp = np.cov(X.T, bias=True)

    # Learn the moral graph using the IAMB Markov network
    moral_mask, _ = iamb_markov_network(X, alpha=alpha)

    # Initialize and run the CalmModel
    device = torch.device(device)
    cov_emp = torch.from_numpy(cov_emp).to(device)
    moral_mask = torch.from_numpy(moral_mask).float().to(device)

    model = CalmModel(d, moral_mask, tau=tau, lambda1=lambda1).to(device)

    # Optimization loop
    rho = rho_init
    for _ in range(100):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        for _ in range(subproblem_iter):
            optimizer.zero_grad()
            loss = model.compute_loss(cov_emp, rho)
            loss.backward(retain_graph=True)
            optimizer.step()

        with torch.no_grad():
            B_logit_copy = model.B_logit.detach().clone()
            B_logit_copy[model.moral_mask == 0] = float('-inf')
            h_sigmoid = model.compute_h(torch.sigmoid(B_logit_copy / model.tau))

        rho *= rho_mult
        if h_sigmoid.item() <= htol or rho > 1e+16:
            break

    # Extract the final binary and weighted adjacency matrices
    params_est = model.get_params()
    B_bin, B_weighted = params_est['B_bin'], params_est['B']

    node_names = [("X%d" % (i + 1)) for i in range(d)]
    nodes = [GraphNode(name) for name in node_names]
    G = GeneralGraph(nodes)

    # Add edges to the GeneralGraph based on B_bin
    for i in range(d):
        for j in range(d):
            if B_bin[i, j] == 1:
                G.add_directed_edge(nodes[i], nodes[j])

    Record = {
        "G": G,  # GeneralGraph object representing the learned causal graph, a DAG
        "B_weighted": B_weighted  # Weighted adjacency matrix of the learned graph
    }

    return Record


class CalmModel(nn.Module):
    """
    The CALM model.

    Parameters
    ----------
    d : int
        Number of variables/nodes in the graph.
    moral_mask : torch.Tensor
        Binary mask representing the moral graph structure, used to restrict possible edges.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid sampling, controlling the sparsity approximation. Default is 0.5.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty (sparsity term). Default is 0.005.
    """

    def __init__(self, d, moral_mask, tau=0.5, lambda1=0.005):
        super(CalmModel, self).__init__()
        self.d = d
        self.moral_mask = moral_mask
        self.tau = tau
        self.lambda1 = lambda1
        self._init_params()

    def _init_params(self):
        """Initialize parameters."""
        self.B_param = nn.Parameter(
            torch.FloatTensor(self.d, self.d).uniform_(-0.001, 0.001).to(self.moral_mask.device)
        )
        self.B_logit = nn.Parameter(
            torch.zeros(self.d, self.d).to(self.moral_mask.device)
        )

    def sample_mask(self):
        """
        Sample a binary mask B_mask via the Gumbel-Sigmoid approximation,
        then apply the moral graph mask to restrict possible edges.
        """
        B_mask = gumbel_sigmoid(self.B_logit, tau=self.tau)
        B_mask = B_mask * self.moral_mask
        return B_mask

    @torch.no_grad()
    def get_params(self):
        """
        Return the estimated adjacency matrices B_bin (binary) and B (weighted), thresholding at 0.5.
        """
        threshold = 0.5
        B_param = self.B_param.cpu().detach().numpy()
        B_logit = self.B_logit.cpu().detach().numpy()
        B_logit[self.moral_mask.cpu().numpy() == 0] = float('-inf')
        B_bin = sigmoid(B_logit / self.tau)
        B_bin[B_bin < threshold] = 0
        B_bin[B_bin >= threshold] = 1
        B = B_bin * B_param
        params = {'B': B, 'B_bin': B_bin}
        return params

    def compute_likelihood(self, B, cov_emp):
        """
        Compute the likelihood-based objective under the non-equal noise variance (NV) assumption.
        """
        I = torch.eye(self.d, device=self.B_param.device)
        residuals = torch.diagonal((I - B).T @ cov_emp @ (I - B))
        likelihood = 0.5 * torch.sum(torch.log(residuals)) - torch.linalg.slogdet(I - B)[1]
        return likelihood

    def compute_sparsity(self, B_mask):
        """
        Compute the sparsity penalty (approximated L0 penalty) by summing the entries of B_mask.
        """
        return B_mask.sum()

    def compute_h(self, B_mask):
        """
        Compute the DAG constraint term, adapted from the DAG constraint formulation
        in Yu et al. (2019).
        """
        return torch.trace(matrix_poly(B_mask, self.d, self.B_param.device)) - self.d

    def compute_loss(self, cov_emp, rho):
        """
        Combine likelihood, approximated L0 penalty (sparsity), and DAG constraint terms into the final loss.
        """
        B_mask = self.sample_mask()
        B = B_mask * self.B_param
        likelihood = self.compute_likelihood(B, cov_emp)
        sparsity = self.lambda1 * self.compute_sparsity(B_mask)
        h = self.compute_h(B_mask)
        loss = likelihood + sparsity + 0.5 * rho * h**2
        return loss
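compute_h relies on matrix_poly (and sample_mask on gumbel_sigmoid) from the star-import of causallearn.utils.CALMUtils, which is not part of this excerpt. Based on the Yu et al. (2019) DAG-GNN constraint the docstring cites, a plausible reconstruction of matrix_poly — an assumption, not the committed helper — is:

import torch

def matrix_poly(B, d, device):
    # Assumed sketch of causallearn.utils.CALMUtils.matrix_poly:
    # computes (I + B/d)^d, so that trace(matrix_poly(B, d, device)) - d == 0
    # exactly when B describes an acyclic graph (Yu et al., 2019).
    eye = torch.eye(d, device=device)
    return torch.linalg.matrix_power(eye + B / d, d)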

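A minimal usage sketch for the new entry point (synthetic data; subproblem_iter is lowered from the 40000 default only so the sketch finishes quickly):

import numpy as np
from causallearn.search.ScoreBased.CALM import calm

rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.standard_normal((n, d))
X[:, 2] = 0.9 * X[:, 0] - 0.7 * X[:, 1] + 0.1 * rng.standard_normal(n)  # X1, X2 -> X3

Record = calm(X, subproblem_iter=2000)
print(Record["G"])           # learned DAG as a GeneralGraph
print(Record["B_weighted"])  # weighted adjacency matrix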
causallearn/search/ScoreBased/ExactSearch.py

Lines changed: 2 additions & 0 deletions
@@ -365,6 +365,8 @@ def bic_score_node(X, i, structure):
                          b=X[:, i],
                          rcond=None)
     bic = n * np.log(residual / n) + len(structure) * np.log(n)
+    if bic.size == 0:
+        return NEGINF  # Return negative infinity if bic is empty
     return bic.item()
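The new guard covers a real NumPy edge case: np.linalg.lstsq returns an empty residuals array when the design matrix is rank-deficient or has no more rows than columns, which makes bic above an empty array and bic.item() raise. A small self-contained demonstration:

import numpy as np

A = np.arange(6.0).reshape(3, 2)        # full rank, overdetermined
b = np.array([1.0, 2.0, 3.0])
_, res, _, _ = np.linalg.lstsq(A, b, rcond=None)
print(res.size)  # 1: residuals are reported

A_deficient = np.column_stack([A[:, 0], A[:, 0]])  # duplicate column -> rank 1
_, res, _, _ = np.linalg.lstsq(A_deficient, b, rcond=None)
print(res.size)  # 0: empty residuals, the case the NEGINF guard handles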
