Skip to content

Commit f9caf66

Browse files
authored
Revert "Graph operations compatible with np array"
1 parent 67637e6 commit f9caf66

File tree

3 files changed

+11
-40
lines changed

3 files changed

+11
-40
lines changed

causallearn/graph/Dag.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
from itertools import combinations
3-
from typing import List, Optional, Union
3+
from typing import List
44

55
import networkx as nx
66
import numpy as np
@@ -18,23 +18,8 @@
1818
# or latent, with at most one edge per node pair, and no edges to self.
1919
class Dag(GeneralGraph):
2020

21-
def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None):
22-
if nodes is not None:
23-
self._init_from_nodes(nodes)
24-
elif graph is not None:
25-
if isinstance(graph, np.ndarray):
26-
nodes = [Node(node_name=str(i)) for i in range(len(graph))]
27-
self._init_from_nodes(nodes)
28-
for i in range(len(nodes)):
29-
for j in range(len(nodes)):
30-
if graph[i, j] == 1:
31-
self.add_directed_edge(nodes[i], nodes[j])
32-
else:
33-
pass
34-
else:
35-
raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'")
36-
37-
def _init_from_nodes(self, nodes: List[Node]):
21+
def __init__(self, nodes: List[Node]):
22+
3823
# for node in nodes:
3924
# if not isinstance(node, type(GraphNode)):
4025
# raise TypeError("Graphs must be instantiated with a list of GraphNodes")

causallearn/graph/Node.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,27 @@
22

33
# Represents an object with a name, node type, and position that can serve as a
44
# node in a graph.
5-
from typing import Optional
65
from causallearn.graph.NodeType import NodeType
76
from causallearn.graph.NodeVariableType import NodeVariableType
87

98

109
class Node:
11-
node_type: NodeType
12-
node_name: str
1310

14-
def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None:
15-
self.node_name = node_name
16-
self.node_type = node_type
17-
1811
# @return the name of the variable.
1912
def get_name(self) -> str:
20-
return self.node_name
13+
pass
2114

2215
# set the name of the variable
2316
def set_name(self, name: str):
24-
self.node_name = name
17+
pass
2518

2619
# @return the node type of the variable
2720
def get_node_type(self) -> NodeType:
28-
return self.node_type
21+
pass
2922

3023
# set the node type of the variable
3124
def set_node_type(self, node_type: NodeType):
32-
self.node_type = node_type
25+
pass
3326

3427
# @return the intervention type
3528
def get_node_variable_type(self) -> NodeVariableType:
@@ -42,7 +35,7 @@ def set_node_variable_type(self, var_type: NodeVariableType):
4235

4336
# @return the name of the node as its string representation
4437
def __str__(self):
45-
return self.node_name
38+
pass
4639

4740
# @return the x coordinate of the center of the node
4841
def get_center_x(self) -> int:
@@ -66,7 +59,7 @@ def set_center(self, center_x: int, center_y: int):
6659

6760
# @return a hashcode for this variable
6861
def __hash__(self):
69-
return hash(self.node_name)
62+
pass
7063

7164
# @return true iff this variable is equal to the given variable
7265
def __eq__(self, other):

causallearn/utils/DAG2CPDAG.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import Union
21
import numpy as np
32

43
from causallearn.graph.Dag import Dag
@@ -7,7 +6,7 @@
76
from causallearn.graph.GeneralGraph import GeneralGraph
87

98

10-
def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
9+
def dag2cpdag(G: Dag) -> GeneralGraph:
1110
"""
1211
Convert a DAG to its corresponding PDAG
1312
@@ -23,13 +22,7 @@ def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
2322
-------
2423
Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU
2524
"""
26-
27-
if isinstance(G, np.ndarray):
28-
# convert np array to Dag graph
29-
G = Dag(graph=G)
30-
elif not isinstance(G, Dag):
31-
raise TypeError("parameter graph should be `Dag` or `np.ndarry`")
32-
25+
3326
# order the edges in G
3427
nodes_order = list(
3528
map(lambda x: G.node_map[x], G.get_causal_ordering())

0 commit comments

Comments
 (0)