import numpy as np
import torch
import torch.nn as nn
from causallearn.utils.MarkovNetwork.iamb import iamb_markov_network
from causallearn.utils.CALMUtils import *
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode
from typing import Any, Dict
from scipy.special import expit as sigmoid

torch.set_default_dtype(torch.double)


def calm(
    X: np.ndarray,
    lambda1: float = 0.005,
    alpha: float = 0.01,
    tau: float = 0.5,
    rho_init: float = 1e-5,
    rho_mult: float = 3,
    htol: float = 1e-8,
    subproblem_iter: int = 40000,
    standardize: bool = False,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """
    Perform the CALM (Continuous and Acyclicity-constrained L0-penalized likelihood with estimated Moral graph) algorithm.

    Parameters
    ----------
    X : numpy.ndarray
        Input dataset of shape (n, d), where n is the number of samples
        and d is the number of variables.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty, which encourages sparsity in the learned graph. Default is 0.005.
    alpha : float, optional
        Significance level for the conditional independence tests. Default is 0.01.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid. Default is 0.5.
    rho_init : float, optional
        Initial value of the penalty parameter for the acyclicity constraint. Default is 1e-5.
    rho_mult : float, optional
        Multiplication factor for rho in each iteration. Default is 3.
    htol : float, optional
        Tolerance level for the acyclicity constraint. Default is 1e-8.
    subproblem_iter : int, optional
        Number of iterations for each subproblem optimization. Default is 40000.
    standardize : bool, optional
        Whether to standardize the input data (mean=0, variance=1). Default is False.
    device : str, optional
        The device to use for computation ('cpu' or 'cuda'). Default is 'cpu'.

    Returns
    -------
    Record : dict
        A dictionary containing:
        - Record['G']: the learned causal graph (a DAG), where Record['G'].graph[j, i] = 1 and Record['G'].graph[i, j] = -1 indicate i --> j.
        - Record['B_weighted']: the weighted adjacency matrix of the learned causal graph.
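
    Examples
    --------
    A minimal usage sketch (illustrative only; the random data below is a placeholder
    for a real dataset, e.g. one generated from a linear SEM):

    >>> import numpy as np
    >>> X = np.random.randn(1000, 5)  # replace with your own (n, d) data
    >>> Record = calm(X, lambda1=0.005, alpha=0.01)
    >>> G = Record['G']               # GeneralGraph object (DAG)
    >>> B = Record['B_weighted']      # weighted adjacency matrix (numpy.ndarray)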
| 58 | + """ |
| 59 | + |
| 60 | + d = X.shape[1] |
| 61 | + if standardize: |
| 62 | + mean_X = np.mean(X, axis=0, keepdims=True) |
| 63 | + std_X = np.std(X, axis=0, keepdims=True) |
| 64 | + X = (X - mean_X) / std_X |
| 65 | + else: |
| 66 | + X = X - np.mean(X, axis=0, keepdims=True) |
| 67 | + |
| 68 | + # Compute the data covariance matrix |
| 69 | + cov_emp = np.cov(X.T, bias=True) |
| 70 | + |
| 71 | + # Learn the moral graph using the IAMB Markov network |
| 72 | + moral_mask, _ = iamb_markov_network(X, alpha=alpha) |
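    # moral_mask is a binary (d, d) matrix; moral_mask[i, j] = 1 means variables i and j are
    # adjacent in the estimated moral graph, so an edge between them is allowed below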

    # Initialize and run the CalmModel
    device = torch.device(device)
    cov_emp = torch.from_numpy(cov_emp).to(device)
    moral_mask = torch.from_numpy(moral_mask).float().to(device)

    model = CalmModel(d, moral_mask, tau=tau, lambda1=lambda1).to(device)

    # Optimization loop
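    # Quadratic-penalty scheme: repeatedly minimize the penalized loss for the current rho,
    # then increase rho until the (relaxed) acyclicity violation h drops below htol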
    rho = rho_init
    for _ in range(100):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        for _ in range(subproblem_iter):
            optimizer.zero_grad()
            loss = model.compute_loss(cov_emp, rho)
            loss.backward(retain_graph=True)
            optimizer.step()

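        # Evaluate the acyclicity violation of the deterministic edge probabilities
        # (sigmoid of the logits), restricted to the moral graph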
        with torch.no_grad():
            B_logit_copy = model.B_logit.detach().clone()
            B_logit_copy[model.moral_mask == 0] = float('-inf')
            h_sigmoid = model.compute_h(torch.sigmoid(B_logit_copy / model.tau))

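        # Increase the penalty weight; stop once the constraint is (numerically) satisfied
        # or rho becomes too large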
        rho *= rho_mult
        if h_sigmoid.item() <= htol or rho > 1e+16:
            break

    # Extract the final binary and weighted adjacency matrices
    params_est = model.get_params()
    B_bin, B_weighted = params_est['B_bin'], params_est['B']

    node_names = [("X%d" % (i + 1)) for i in range(d)]
    nodes = [GraphNode(name) for name in node_names]
    G = GeneralGraph(nodes)

    # Add edges to the GeneralGraph based on B_bin
    for i in range(d):
        for j in range(d):
            if B_bin[i, j] == 1:
                G.add_directed_edge(nodes[i], nodes[j])

    Record = {
        "G": G,  # GeneralGraph object representing the learned causal graph, a DAG
        "B_weighted": B_weighted  # Weighted adjacency matrix of the learned graph
    }

    return Record


class CalmModel(nn.Module):
    """
    The CALM model

    Parameters
    ----------
    d : int
        Number of variables/nodes in the graph.
    moral_mask : torch.Tensor
        Binary mask representing the moral graph structure, used to restrict possible edges.
    tau : float, optional
        Temperature parameter for the Gumbel-Sigmoid sampling, controlling the sparsity approximation. Default is 0.5.
    lambda1 : float, optional
        Coefficient for the approximated L0 penalty (sparsity term). Default is 0.005.
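
    Notes
    -----
    A standalone instantiation sketch (illustrative only; in practice `calm` builds the
    moral-graph mask via `iamb_markov_network` and drives the optimization loop):

    >>> d = 3
    >>> mask = torch.ones(d, d) - torch.eye(d)  # hypothetical fully connected moral graph
    >>> model = CalmModel(d, mask, tau=0.5, lambda1=0.005)
    >>> cov = torch.eye(d)                      # placeholder empirical covariance
    >>> loss = model.compute_loss(cov, rho=1e-5)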
| 135 | + """ |
| 136 | + def __init__(self, d, moral_mask, tau=0.5, lambda1=0.005): |
| 137 | + super(CalmModel, self).__init__() |
| 138 | + self.d = d |
| 139 | + self.moral_mask = moral_mask |
| 140 | + self.tau = tau |
| 141 | + self.lambda1 = lambda1 |
| 142 | + self._init_params() |
| 143 | + |
| 144 | + def _init_params(self): |
| 145 | + """Initialize parameters""" |
| 146 | + self.B_param = nn.Parameter( |
| 147 | + torch.FloatTensor(self.d, self.d).uniform_(-0.001, 0.001).to(self.moral_mask.device) |
| 148 | + ) |
| 149 | + self.B_logit = nn.Parameter( |
| 150 | + torch.zeros(self.d, self.d).to(self.moral_mask.device) |
| 151 | + ) |
| 152 | + |
| 153 | + def sample_mask(self): |
| 154 | + """ |
| 155 | + Samples a binary mask B_mask based on the Gumbel-Sigmoid approximation. |
| 156 | + Applies the moral graph mask to restrict possible edges. |
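
        Note: `gumbel_sigmoid` (imported from CALMUtils) is assumed to draw a relaxed
        Bernoulli sample for each entry of `B_logit` at temperature `tau`, so the
        returned mask has entries in [0, 1] and stays differentiable w.r.t. the logits.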
| 157 | + """ |
| 158 | + B_mask = gumbel_sigmoid(self.B_logit, tau=self.tau) |
| 159 | + B_mask = B_mask * self.moral_mask |
| 160 | + return B_mask |
| 161 | + |
| 162 | + @torch.no_grad() |
| 163 | + def get_params(self): |
| 164 | + """ |
        Returns the estimated binary adjacency matrix B_bin and the weighted adjacency matrix B,
        obtained by thresholding the (masked) edge probabilities at 0.5.
| 166 | + """ |
| 167 | + threshold = 0.5 |
| 168 | + B_param = self.B_param.cpu().detach().numpy() |
| 169 | + B_logit = self.B_logit.cpu().detach().numpy() |
| 170 | + B_logit[self.moral_mask.cpu().numpy() == 0] = float('-inf') |
| 171 | + B_bin = sigmoid(B_logit / self.tau) |
| 172 | + B_bin[B_bin < threshold] = 0 |
| 173 | + B_bin[B_bin >= threshold] = 1 |
| 174 | + B = B_bin * B_param |
| 175 | + params = {'B': B, 'B_bin': B_bin} |
| 176 | + return params |
| 177 | + |
| 178 | + def compute_likelihood(self, B, cov_emp): |
| 179 | + """ |
        Computes the likelihood-based objective function under the non-equal noise variance (NV) assumption.
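
        For a weighted adjacency matrix B and empirical covariance cov_emp, this evaluates
        0.5 * sum_j log([(I - B)^T cov_emp (I - B)]_{jj}) - log|det(I - B)|,
        i.e. the negative Gaussian log-likelihood up to additive constants, with the
        node-wise noise variances profiled out.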
| 181 | + """ |
| 182 | + I = torch.eye(self.d, device=self.B_param.device) |
| 183 | + residuals = torch.diagonal((I - B).T @ cov_emp @ (I - B)) |
| 184 | + likelihood = 0.5 * torch.sum(torch.log(residuals)) - torch.linalg.slogdet(I - B)[1] |
| 185 | + return likelihood |
| 186 | + |
| 187 | + def compute_sparsity(self, B_mask): |
| 188 | + """ |
| 189 | + Computes the sparsity penalty (approximated L0 penalty) by summing the binary entries in B_mask. |
| 190 | + """ |
| 191 | + return B_mask.sum() |
| 192 | + |
| 193 | + def compute_h(self, B_mask): |
| 194 | + """ |
| 195 | + Computes the DAG constraint term, adapted from the DAG constraint formulation |
| 196 | + in Yu et al. (2019). |
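
        Assuming `matrix_poly(M, d, device)` computes (I + M / d)^d, as in the DAG-GNN
        reference implementation, this returns h(M) = tr[(I + M / d)^d] - d, which is
        zero exactly when the nonnegative mask M corresponds to an acyclic graph.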
| 197 | + """ |
| 198 | + return torch.trace(matrix_poly(B_mask, self.d, self.B_param.device)) - self.d |
| 199 | + |
| 200 | + def compute_loss(self, cov_emp, rho): |
| 201 | + """ |
| 202 | + Combines likelihood, approximated L0 penalty (sparsity), and DAG constraint terms into the final loss function. |
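
        Concretely, loss = likelihood(B) + lambda1 * sum(B_mask) + 0.5 * rho * h(B_mask)^2,
        where B = B_mask * B_param for a freshly sampled Gumbel-Sigmoid mask B_mask; this is
        a quadratic-penalty relaxation of the acyclicity-constrained problem.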
| 203 | + """ |
| 204 | + B_mask = self.sample_mask() |
| 205 | + B = B_mask * self.B_param |
| 206 | + likelihood = self.compute_likelihood(B, cov_emp) |
| 207 | + sparsity = self.lambda1 * self.compute_sparsity(B_mask) |
| 208 | + h = self.compute_h(B_mask) |
| 209 | + loss = likelihood + sparsity + 0.5 * rho * h**2 |
| 210 | + return loss |
| 211 | + |
| 212 | + |