
Commit fedcbc3

initial upload for CALM algorithm
1 parent d450dd8 commit fedcbc3

File tree

causallearn/search/ScoreBased/CALM.py
causallearn/utils/CALMUtils.py
causallearn/utils/MarkovNetwork/__init__.py
causallearn/utils/MarkovNetwork/iamb.py

4 files changed: +288 -0 lines changed

causallearn/search/ScoreBased/CALM.py

Lines changed: 212 additions & 0 deletions
import numpy as np
import torch
import torch.nn as nn
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network
from causallearn.utils.CALMUtils import *
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from typing import Any, Dict
from scipy.special import expit as sigmoid

torch.set_default_dtype(torch.double)


def calm(
    X: np.ndarray,
    lambda1: float = 0.005,
    alpha: float = 0.01,
    tau: float = 0.5,
    rho_init: float = 1e-5,
    rho_mult: float = 3,
    htol: float = 1e-8,
    subproblem_iter: int = 40000,
    standardize: bool = False,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """
    Perform the CALM (Continuous and Acyclicity-constrained L0-penalized likelihood with estimated Moral graph) algorithm.

    Parameters
    ----------
    X : numpy.ndarray
        Input dataset of shape (n, d), where n is the number of samples
        and d is the number of variables.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty, which encourages sparsity in the learned graph. Default is 0.005.
    alpha : float, optional
        Significance level for conditional independence tests. Default is 0.01.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid. Default is 0.5.
    rho_init : float, optional
        Initial value of the penalty parameter for the acyclicity constraint. Default is 1e-5.
    rho_mult : float, optional
        Multiplication factor for rho in each iteration. Default is 3.
    htol : float, optional
        Tolerance level for the acyclicity constraint. Default is 1e-8.
    subproblem_iter : int, optional
        Number of iterations for each subproblem optimization. Default is 40000.
    standardize : bool, optional
        Whether to standardize the input data (mean=0, variance=1). Default is False.
    device : str, optional
        The device to use for computation ('cpu' or 'cuda'). Default is 'cpu'.

    Returns
    -------
    Record : dict
        A dictionary containing:
        - Record['G']: the learned causal graph (a DAG), where Record['G'].graph[j,i] = 1 and Record['G'].graph[i,j] = -1 indicates i --> j.
        - Record['B_weighted']: weighted adjacency matrix of the learned causal graph.
    """

    d = X.shape[1]
    if standardize:
        mean_X = np.mean(X, axis=0, keepdims=True)
        std_X = np.std(X, axis=0, keepdims=True)
        X = (X - mean_X) / std_X
    else:
        X = X - np.mean(X, axis=0, keepdims=True)

    # Compute the empirical data covariance matrix
    cov_emp = np.cov(X.T, bias=True)

    # Learn the moral graph using the IAMB Markov network
    moral_mask, _ = iamb_markov_network(X, alpha=alpha)

    # Initialize and run the CalmModel
    device = torch.device(device)
    cov_emp = torch.from_numpy(cov_emp).to(device)
    moral_mask = torch.from_numpy(moral_mask).float().to(device)

    model = CalmModel(d, moral_mask, tau=tau, lambda1=lambda1).to(device)

    # Optimization loop: increase the penalty weight rho until the learned mask is acyclic
    rho = rho_init
    for _ in range(100):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        for _ in range(subproblem_iter):
            optimizer.zero_grad()
            loss = model.compute_loss(cov_emp, rho)
            loss.backward(retain_graph=True)
            optimizer.step()

        # Evaluate acyclicity of the deterministic (noise-free) mask,
        # restricted to edges allowed by the moral graph
        with torch.no_grad():
            B_logit_copy = model.B_logit.detach().clone()
            B_logit_copy[model.moral_mask == 0] = float('-inf')
            h_sigmoid = model.compute_h(torch.sigmoid(B_logit_copy / model.tau))

        rho *= rho_mult
        if h_sigmoid.item() <= htol or rho > 1e+16:
            break

    # Extract the final binary and weighted adjacency matrices
    params_est = model.get_params()
    B_bin, B_weighted = params_est['B_bin'], params_est['B']

    node_names = [("X%d" % (i + 1)) for i in range(d)]
    nodes = [GraphNode(name) for name in node_names]
    G = GeneralGraph(nodes)

    # Add edges to the GeneralGraph based on B_bin
    for i in range(d):
        for j in range(d):
            if B_bin[i, j] == 1:
                G.add_directed_edge(nodes[i], nodes[j])

    Record = {
        "G": G,  # GeneralGraph object representing the learned causal graph, a DAG
        "B_weighted": B_weighted  # Weighted adjacency matrix of the learned graph
    }

    return Record


class CalmModel(nn.Module):
    """
    The CALM model.

    Parameters
    ----------
    d : int
        Number of variables/nodes in the graph.
    moral_mask : torch.Tensor
        Binary mask representing the moral graph structure, used to restrict possible edges.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid sampling, controlling the sparsity approximation. Default is 0.5.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty (sparsity term). Default is 0.005.
    """

    def __init__(self, d, moral_mask, tau=0.5, lambda1=0.005):
        super(CalmModel, self).__init__()
        self.d = d
        self.moral_mask = moral_mask
        self.tau = tau
        self.lambda1 = lambda1
        self._init_params()

    def _init_params(self):
        """Initialize parameters."""
        self.B_param = nn.Parameter(
            torch.FloatTensor(self.d, self.d).uniform_(-0.001, 0.001).to(self.moral_mask.device)
        )
        self.B_logit = nn.Parameter(
            torch.zeros(self.d, self.d).to(self.moral_mask.device)
        )

    def sample_mask(self):
        """
        Samples a binary mask B_mask based on the Gumbel-Sigmoid approximation.
        Applies the moral graph mask to restrict possible edges.
        """
        B_mask = gumbel_sigmoid(self.B_logit, tau=self.tau)
        B_mask = B_mask * self.moral_mask
        return B_mask

    @torch.no_grad()
    def get_params(self):
        """
        Returns the estimated adjacency matrices B_bin (binary) and B (weighted), thresholding edge probabilities at 0.5.
        """
        threshold = 0.5
        B_param = self.B_param.cpu().detach().numpy()
        B_logit = self.B_logit.cpu().detach().numpy()
        B_logit[self.moral_mask.cpu().numpy() == 0] = float('-inf')
        B_bin = sigmoid(B_logit / self.tau)
        B_bin[B_bin < threshold] = 0
        B_bin[B_bin >= threshold] = 1
        B = B_bin * B_param
        params = {'B': B, 'B_bin': B_bin}
        return params

    def compute_likelihood(self, B, cov_emp):
        """
        Computes the likelihood-based objective under the non-equal noise variance (NV) assumption.
        """
        I = torch.eye(self.d, device=self.B_param.device)
        residuals = torch.diagonal((I - B).T @ cov_emp @ (I - B))
        likelihood = 0.5 * torch.sum(torch.log(residuals)) - torch.linalg.slogdet(I - B)[1]
        return likelihood

    def compute_sparsity(self, B_mask):
        """
        Computes the sparsity penalty (approximated L0 penalty) by summing the binary entries in B_mask.
        """
        return B_mask.sum()

    def compute_h(self, B_mask):
        """
        Computes the DAG constraint term, adapted from the DAG constraint formulation
        in Yu et al. (2019).
        """
        return torch.trace(matrix_poly(B_mask, self.d, self.B_param.device)) - self.d

    def compute_loss(self, cov_emp, rho):
        """
        Combines likelihood, approximated L0 penalty (sparsity), and DAG constraint terms into the final loss function.
        """
        B_mask = self.sample_mask()
        B = B_mask * self.B_param
        likelihood = self.compute_likelihood(B, cov_emp)
        sparsity = self.lambda1 * self.compute_sparsity(B_mask)
        h = self.compute_h(B_mask)
        loss = likelihood + sparsity + 0.5 * rho * h**2
        return loss
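A minimal usage sketch for the calm entry point above (the synthetic chain data, weights, and reduced iteration count are illustrative, not part of this commit):

import numpy as np
from causallearn.search.ScoreBased.CALM import calm

# Simulate a linear SEM over the chain X1 --> X2 --> X3 (hypothetical data)
rng = np.random.default_rng(0)
n = 1000
X1 = rng.normal(size=n)
X2 = 0.8 * X1 + rng.normal(size=n)
X3 = -0.6 * X2 + rng.normal(size=n)
X = np.column_stack([X1, X2, X3])

# subproblem_iter is lowered from the 40000 default only to keep the demo fast
Record = calm(X, lambda1=0.005, alpha=0.01, subproblem_iter=2000)
print(Record['G'])           # GeneralGraph: graph[j, i] = 1 and graph[i, j] = -1 encodes i --> j
print(Record['B_weighted'])  # weighted adjacency matrix of the learned DAG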

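CalmModel can also be smoke-tested directly; a sketch under the assumption of a fully connected moral mask (the data and mask are placeholders, not from this commit):

import numpy as np
import torch
from causallearn.search.ScoreBased.CALM import CalmModel

torch.set_default_dtype(torch.double)
d = 3
X = np.random.default_rng(0).normal(size=(500, d))  # placeholder data
cov_emp = torch.from_numpy(np.cov(X.T, bias=True))
moral_mask = torch.ones(d, d) - torch.eye(d)        # assumption: allow every edge

model = CalmModel(d, moral_mask, tau=0.5, lambda1=0.005)
loss = model.compute_loss(cov_emp, rho=1e-5)  # one stochastic evaluation of the objective
loss.backward()                               # gradients flow into both B_param and B_logit
print(loss.item(), model.B_logit.grad.norm().item())
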
causallearn/utils/CALMUtils.py

Lines changed: 16 additions & 0 deletions
import torch

def sample_logistic(shape, out=None):
    # U ~ Uniform(0, 1); log(U) - log(1 - U) follows a standard logistic distribution
    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
    return torch.log(U) - torch.log(1 - U)


def gumbel_sigmoid(logits, tau=1):
    # Relaxed Bernoulli sample: sigmoid((logits + logistic noise) / tau)
    logistic_noise = sample_logistic(logits.size(), out=logits.data.new())
    y = logits + logistic_noise
    return torch.sigmoid(y / tau)


def matrix_poly(matrix, d, device):
    # (I + B/d)^d, used in the DAG constraint h(B) = tr((I + B/d)^d) - d
    x = torch.eye(d, device=device, dtype=matrix.dtype) + torch.div(matrix, d)
    return torch.matrix_power(x, d)
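A small sanity check for these two helpers (the example matrices are illustrative): lower temperatures push the Gumbel-Sigmoid samples toward binary, and the trace term used for the DAG constraint vanishes exactly on acyclic masks.

import torch
from causallearn.utils.CALMUtils import gumbel_sigmoid, matrix_poly

torch.manual_seed(0)
logits = torch.zeros(3, 3)
print(gumbel_sigmoid(logits, tau=1.0))  # samples spread over (0, 1)
print(gumbel_sigmoid(logits, tau=0.1))  # samples close to 0 or 1

# h(B) = tr((I + B/d)^d) - d is 0 for an acyclic mask and > 0 for a cyclic one
d = 3
acyclic = torch.tensor([[0., 1., 0.],
                        [0., 0., 1.],
                        [0., 0., 0.]])   # chain 0 -> 1 -> 2
cyclic = torch.tensor([[0., 1., 0.],
                       [0., 0., 1.],
                       [1., 0., 0.]])    # cycle 0 -> 1 -> 2 -> 0
h = lambda B: (torch.trace(matrix_poly(B, d, B.device)) - d).item()
print(h(acyclic))  # 0.0
print(h(cyclic))   # 1/9, from the tr(B^3)/d^3 term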

causallearn/utils/MarkovNetwork/__init__.py

Whitespace-only changes.
causallearn/utils/MarkovNetwork/iamb.py

Lines changed: 60 additions & 0 deletions
import causallearn.utils.cit as cit
import numpy as np

def iamb_markov_network(X, alpha=0.05):
    n, d = X.shape
    markov_network_raw = np.zeros((d, d))
    total_num_ci = 0
    cond_indep_test = cit.CIT(X, 'fisherz')
    # Estimate the Markov blanket of each variable
    for i in range(d):
        markov_blanket, num_ci = iamb(cond_indep_test, d, i, alpha)
        total_num_ci += num_ci
        if len(markov_blanket) > 0:
            markov_network_raw[i, markov_blanket] = 1
            markov_network_raw[markov_blanket, i] = 1

    # AND rule: (i, j) is an edge in the Markov network
    # if and only if i and j are in each other's Markov blanket
    # TODO: Check whether we should use the AND rule or the OR rule
    markov_network = np.logical_and(markov_network_raw, markov_network_raw.T).astype(float)
    return markov_network, total_num_ci


def iamb(cond_indep_test, d, target, alpha):
    # Modified from: https://github.com/wt-hu/pyCausalFS/blob/master/pyCausalFS/CBD/MBs/IAMB.py
    markov_blanket = []
    num_ci = 0
    # Forward phase: greedily add the most dependent variable
    circulate_flag = True
    while circulate_flag:
        # If no variable is added, the forward phase of IAMB is finished
        circulate_flag = False
        min_pval = float('inf')
        y = None
        variables = [i for i in range(d) if i != target and i not in markov_blanket]
        for x in variables:
            num_ci += 1
            pval = cond_indep_test(target, x, markov_blanket)
            # Among the dependent variables, pick the one with the smallest p-value
            if pval <= alpha:
                if pval < min_pval:
                    min_pval = pval
                    y = x

        # If some variable is dependent on the target given the current blanket, add it
        if y is not None:
            markov_blanket.append(y)
            circulate_flag = True

    # Backward phase: remove variables that became conditionally independent
    markov_blanket_temp = markov_blanket.copy()
    for x in markov_blanket_temp:
        # Condition on the rest of the blanket, excluding the variable under test
        condition_variables = [i for i in markov_blanket if i != x]
        num_ci += 1
        pval = cond_indep_test(target, x, condition_variables)
        if pval > alpha:
            markov_blanket.remove(x)

    return list(set(markov_blanket)), num_ci
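A usage sketch for the moral-graph estimator (synthetic chain data, illustrative only): for X1 --> X2 --> X3 there are no colliders, so the moral graph should connect only (X1, X2) and (X2, X3).

import numpy as np
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network

rng = np.random.default_rng(0)
n = 2000
X1 = rng.normal(size=n)
X2 = X1 + rng.normal(size=n)
X3 = X2 + rng.normal(size=n)
X = np.column_stack([X1, X2, X3])

moral_mask, num_tests = iamb_markov_network(X, alpha=0.05)
print(moral_mask)  # expected: 1s at (0,1)/(1,0) and (1,2)/(2,1) only
print(num_tests)   # total number of Fisher-z CI tests performed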
