Skip to content

Commit 7924253

Browse files
authored
Merge pull request #194 from EvieQ01/main
Add orientation rules 567 for Augmented FCI
2 parents 0021a13 + 0e076fc commit 7924253

15 files changed

+297
-59
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 206 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from queue import Queue
5-
from typing import List, Set, Tuple, Dict
5+
from typing import List, Set, Tuple, Dict, Generator
66
from numpy import ndarray
77

88
from causallearn.graph.Edge import Edge
@@ -17,6 +17,19 @@
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
1818
from itertools import combinations
1919

20+
def is_uncovered_path(nodes: List[Node], G: Graph) -> bool:
    """
    Report whether a node sequence forms an uncovered path in the graph.

    The check looks at every pair of nodes that sit two positions apart in
    ``nodes`` (``nodes[i]`` and ``nodes[i + 2]``). The path is uncovered when
    none of those pairs is adjacent in ``G``. A sequence with fewer than three
    nodes has no such pair and is therefore reported as uncovered.
    """
    # Pair each node with the node two steps further along the path.
    skip_one_pairs = zip(nodes, nodes[2:])
    return not any(G.is_adjacent_to(left, right) for left, right in skip_one_pairs)
31+
32+
2033
def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2134
if node == edge.get_node1():
2235
if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE:
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2639
return edge.get_node1()
2740
return None
2841

42+
def traverseCircle(node: Node, edge: Edge) -> Node | None:
    """
    Step across ``edge`` from ``node`` when the edge is a circle-circle (o-o) edge.

    Returns the node at the opposite end of ``edge`` if both endpoints of the
    edge are ``Endpoint.CIRCLE`` and ``node`` is one of the edge's two nodes.
    Returns ``None`` in every other case.
    """
    both_circles = (edge.get_endpoint1() == Endpoint.CIRCLE
                    and edge.get_endpoint2() == Endpoint.CIRCLE)
    if not both_circles:
        return None
    if node == edge.get_node1():
        return edge.get_node2()
    if node == edge.get_node2():
        return edge.get_node1()
    return None
50+
2951

30-
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
52+
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## TODO: Now it does not detect whether the path is an uncovered path
3153
Q = Queue()
3254
V = set()
3355

@@ -60,6 +82,42 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
6082

6183
return False
6284

85+
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph,
                           exclude_node: List[Node]) -> Generator[List[Node], None, None]:
    """
    Lazily search for an uncovered circle path from ``node_from`` to ``node_to``.

    The search is breadth-first and only follows edges accepted by
    ``traverseCircle`` (both endpoints are circles). Nodes listed in
    ``exclude_node`` are never entered. When ``node_to`` is dequeued, the path
    that reached it is yielded if ``is_uncovered_path`` accepts it.

    Parameters
    ----------
    node_from : Node
        First node of the path.
    node_to : Node
        Last node of the path.
    G : Graph
        Graph to search.
    exclude_node : List[Node]
        Nodes that must not appear on the path.

    Yields
    ------
    List[Node]
        The nodes of the path, beginning with ``node_from`` and ending with
        ``node_to``. A path always contains at least one edge, so nothing is
        yielded when ``node_from`` and ``node_to`` are the same node.

    Notes
    -----
    * This is a generator: the graph is read while the result is iterated,
      not when the function is called.
    * Every node is enqueued at most once, so ``node_to`` is reached by a
      single route and at most one path is yielded.
      NOTE(review): if that first-found route is covered, other routes are
      not tried, so an existing uncovered circle path can be missed.
    """
    Q = Queue()
    # The start node counts as visited so that no path can loop back through it.
    V = {node_from}
    Q.put((node_from, [node_from]))

    while not Q.empty():
        node_t, path = Q.get_nowait()

        if node_t == node_to:
            if len(path) > 1 and is_uncovered_path(path, G):
                yield path
            # node_to is enqueued only once, so no further path can be found.
            return

        for node_u in G.get_adjacent_nodes(node_t):
            edge = G.get_edge(node_t, node_u)
            node_c = traverseCircle(node_t, edge)

            if node_c is None or node_c in exclude_node:
                continue

            if node_c not in V:
                V.add(node_c)
                Q.put((node_c, path + [node_c]))
119+
120+
63121

64122
def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool:
65123
if node_w == node_x:
@@ -371,6 +429,131 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
371429
changeFlag = True
372430
return changeFlag
373431

432+
def ruleR5(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Rule R5 of the FCI algorithm.
    (Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the
    presence of latent confounders and selection bias")

    For every edge A o-o B this function looks for an uncovered circle path
    A o-o C ... D o-o B such that A is not adjacent to D and B is not adjacent to C.
    When such a path is found, A o-o B and every edge on the path are re-oriented as
    double-tail (undirected) edges.

    Parameters
    ----------
    graph : Graph
        Graph to orient; it is modified in place.
    changeFlag : bool
        Flag carried over from the previously applied rules.
    verbose : bool
        When True, print each edge that gets oriented.

    Returns
    -------
    bool
        True if at least one path was oriented, otherwise ``changeFlag`` as passed in.
    """

    def make_undirected(node_x, node_y):
        # Replace the current edge between the two nodes by node_x --- node_y.
        graph.remove_edge(graph.get_edge(node_x, node_y))
        graph.add_edge(Edge(node_x, node_y, Endpoint.TAIL, Endpoint.TAIL))

    def circle_neighbours(node, skip_indices):
        # Neighbours joined to `node` by an o-o edge, leaving out A and B themselves.
        return [other for other in graph.get_adjacent_nodes(node)
                if graph.node_map[other] not in skip_indices
                and graph.get_endpoint(other, node) == Endpoint.CIRCLE
                and graph.get_endpoint(node, other) == Endpoint.CIRCLE]

    def orient_on_path_helper(path, node_A, node_B):
        # orient A - C, D - B
        make_undirected(node_A, path[0])
        make_undirected(node_B, path[-1])
        if verbose:
            print("Orienting edge A - C (Double tail): " + graph.get_edge(node_A, path[0]).__str__())
            print("Orienting edge B - D (Double tail): " + graph.get_edge(node_B, path[-1]).__str__())

        # orient everything on the path to both tails
        for node_x, node_y in zip(path, path[1:]):
            make_undirected(node_x, node_y)
            if verbose:
                print("Orienting edge (Double tail): " + graph.get_edge(node_x, node_y).__str__())

    for node_B in graph.get_nodes():
        for node_A in graph.get_nodes_into(node_B, Endpoint.CIRCLE):
            # Only A o-o B qualifies: the endpoint at A must be a circle as well.
            if graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE:
                continue

            ab_indices = (graph.node_map[node_A], graph.node_map[node_B])
            a_circle_adj_nodes = circle_neighbours(node_A, ab_indices)
            b_circle_adj_nodes = circle_neighbours(node_B, ab_indices)

            for node_C in a_circle_adj_nodes:
                if graph.is_adjacent_to(node_B, node_C):
                    continue
                for node_D in b_circle_adj_nodes:
                    if graph.is_adjacent_to(node_A, node_D):
                        continue
                    # Uncovered circle path between C and D that avoids A and B.
                    for path in GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph,
                                                       exclude_node=[node_A, node_B]):
                        # The rule needs the whole path A, C, ..., D, B to be uncovered, so the
                        # triples that involve A and B have to be checked too, not just C..D.
                        if not is_uncovered_path([node_A] + path + [node_B], graph):
                            continue
                        changeFlag = True
                        if verbose:
                            print("Find uncovered circle path between A and B: " + graph.get_edge(node_A, node_B).__str__())
                        make_undirected(node_A, node_B)
                        orient_on_path_helper(path, node_A, node_B)

    return changeFlag
508+
509+
def ruleR6(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Orientation rule R6: for A - B o-* C, put a tail at B on the edge B o-* C.

    For each node B that has some neighbour A with a tail at both ends of the
    A/B edge, every edge whose endpoint at B is a circle is rebuilt with a tail
    at B. The endpoint at the other node C is kept as it was.

    Returns True if any edge was rebuilt, otherwise ``changeFlag`` as passed in.
    """
    for node_B in graph.get_nodes():
        # Find A - B
        has_undirected_edge = any(
            graph.get_endpoint(node_B, node_A) == Endpoint.TAIL
            for node_A in graph.get_nodes_into(node_B, Endpoint.TAIL)
        )
        if not has_undirected_edge:
            continue

        # Find B o-*C
        for node_C in graph.get_nodes_into(node_B, Endpoint.CIRCLE):
            old_edge = graph.get_edge(node_B, node_C)
            endpoint_at_C = old_edge.get_proximal_endpoint(node_C)
            graph.remove_edge(old_edge)
            graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, endpoint_at_C))
            changeFlag = True
            if verbose:
                print("Orienting edge by rule 6): " + graph.get_edge(node_B, node_C).__str__())

    return changeFlag
533+
534+
535+
def ruleR7(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Orientation rule R7: for A -o B o-* C with A and C not adjacent, put a tail
    at B on the edge B o-* C.

    Parameters
    ----------
    graph : Graph
        Graph to orient; it is modified in place.
    changeFlag : bool
        Flag carried over from the previously applied rules.
    verbose : bool
        When True, print each edge that gets oriented.

    Returns
    -------
    bool
        True if at least one edge was oriented, otherwise ``changeFlag`` as passed in.
    """
    for node_B in graph.get_nodes():
        # Find A -o B
        intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
        node_A_list = [node for node in intoBCircles if graph.get_endpoint(node_B, node) == Endpoint.TAIL]
        if not node_A_list:
            continue

        # Find B o-*C
        for node_C in intoBCircles:
            for node_A in node_A_list:
                # A and C must be two different nodes. Without this guard the pair (A, A) is
                # handed to is_adjacent_to, and an edge A -o B could orient itself.
                if node_A == node_C:
                    continue
                if graph.is_adjacent_to(node_A, node_C):
                    continue
                changeFlag = True
                edge = graph.get_edge(node_B, node_C)
                graph.remove_edge(edge)
                graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
                if verbose:
                    print("Orienting edge by rule 7): " + graph.get_edge(node_B, node_C).__str__())
                # One qualifying A is enough; the edge now has its tail at B.
                break
    return changeFlag
374557

375558
def getPath(node_c: Node, previous) -> List[Node]:
376559
l = []
@@ -544,9 +727,8 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
544727

545728

546729

547-
def rule8(graph: Graph, nodes: List[Node]):
548-
nodes = graph.get_nodes()
549-
changeFlag = False
730+
def rule8(graph: Graph, nodes: List[Node], changeFlag):
731+
nodes = graph.get_nodes() if nodes is None else nodes
550732
for node_B in nodes:
551733
adj = graph.get_adjacent_nodes(node_B)
552734
if len(adj) < 2:
@@ -601,9 +783,9 @@ def find_possible_children(graph: Graph, parent_node, en_nodes=None):
601783

602784
return potential_child_nodes
603785

604-
def rule9(graph: Graph, nodes: List[Node]):
605-
changeFlag = False
606-
nodes = graph.get_nodes()
786+
def rule9(graph: Graph, nodes: List[Node], changeFlag):
787+
# changeFlag = False
788+
nodes = graph.get_nodes() if nodes is None else nodes
607789
for node_C in nodes:
608790
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
609791
for node_A in intoCArrows:
@@ -629,8 +811,8 @@ def rule9(graph: Graph, nodes: List[Node]):
629811
return changeFlag
630812

631813

632-
def rule10(graph: Graph):
633-
changeFlag = False
814+
def rule10(graph: Graph, changeFlag):
815+
# changeFlag = False
634816
nodes = graph.get_nodes()
635817
for node_C in nodes:
636818
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
@@ -895,6 +1077,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
8951077
graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
8961078
knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress)
8971079

1080+
# pdb.set_trace()
8981081
reorientAllWith(graph, Endpoint.CIRCLE)
8991082

9001083
rule0(graph, nodes, sep_sets, background_knowledge, verbose)
@@ -925,12 +1108,22 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
9251108
if verbose:
9261109
print("Epoch")
9271110

1111+
# rule 5
1112+
change_flag = ruleR5(graph, change_flag, verbose)
1113+
1114+
# rule 6
1115+
change_flag = ruleR6(graph, change_flag, verbose)
1116+
1117+
# rule 7
1118+
change_flag = ruleR7(graph, change_flag, verbose)
1119+
9281120
# rule 8
929-
change_flag = rule8(graph,nodes)
1121+
change_flag = rule8(graph,nodes, change_flag)
1122+
9301123
# rule 9
931-
change_flag = rule9(graph, nodes)
1124+
change_flag = rule9(graph, nodes, change_flag)
9321125
# rule 10
933-
change_flag = rule10(graph)
1126+
change_flag = rule10(graph, change_flag)
9341127

9351128
graph.set_pag(True)
9361129

causallearn/utils/DAG2PAG.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
from causallearn.graph.Endpoint import Endpoint
1111
from causallearn.graph.GeneralGraph import GeneralGraph
1212
from causallearn.graph.Node import Node
13-
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B
13+
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B, ruleR5, ruleR6, ruleR7, rule8, rule9, rule10
1414
from causallearn.utils.cit import CIT, d_separation
1515

16-
def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
16+
17+
def dag2pag(dag: Dag, islatent: List[Node], isselection: List[Node] = []) -> GeneralGraph:
1718
"""
1819
Convert a DAG to its corresponding PAG
1920
Parameters
@@ -27,8 +28,8 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
2728
dg = nx.DiGraph()
2829
true_dag = nx.DiGraph()
2930
nodes = dag.get_nodes()
30-
observed_nodes = list(set(nodes) - set(islatent))
31-
mod_nodes = observed_nodes + islatent
31+
observed_nodes = list(set(nodes) - set(islatent) - set(isselection))
32+
mod_nodes = observed_nodes + islatent + isselection
3233
nodes = dag.get_nodes()
3334
nodes_ids = {node: i for i, node in enumerate(nodes)}
3435
mod_nodeids = {node: i for i, node in enumerate(mod_nodes)}
@@ -65,7 +66,7 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
6566
for Z in combinations(observed_nodes, l):
6667
if nodex in Z or nodey in Z:
6768
continue
68-
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z)):
69+
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z) | set([nodes_ids[s] for s in isselection])):
6970
if edge:
7071
PAG.remove_edge(edge)
7172
sepset[(nodes_ids[nodex], nodes_ids[nodey])] |= set(Z)
@@ -105,6 +106,13 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
105106
change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset_reindexed,
106107
change_flag=change_flag,
107108
bk=None, verbose=False)
109+
change_flag = ruleR5(PAG, changeFlag=change_flag, verbose=True)
110+
change_flag = ruleR6(PAG, changeFlag=change_flag)
111+
change_flag = ruleR7(PAG, changeFlag=change_flag)
112+
change_flag = rule8(PAG, nodes=observed_nodes, changeFlag=change_flag)
113+
change_flag = rule9(PAG, nodes=observed_nodes, changeFlag=change_flag)
114+
change_flag = rule10(PAG, changeFlag=change_flag)
115+
108116
return PAG
109117

110118

tests/TestDAG2PAG.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,17 @@ def test_case3(self):
7272
print(pag)
7373
graphviz_pag = GraphUtils.to_pgv(pag)
7474
graphviz_pag.draw("pag.png", prog='dot', format='png')
75+
76+
def test_case_selection(self):
    """Build a five-node DAG, pass node 4 to dag2pag as a selection node, and print the PAG."""
    nodes = [GraphNode(str(i)) for i in range(5)]
    dag = Dag(nodes)
    # Chain 0 -> 1 -> 2 -> 3, then the two edges into the selection node 4.
    for tail, head in ((0, 1), (1, 2), (2, 3), (3, 4), (0, 4)):
        dag.add_directed_edge(nodes[tail], nodes[head])
    pag = dag2pag(dag, islatent=[], isselection=[nodes[4]])
    print(pag)

tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_alarm_fci_chisq_0.05.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
44
0 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
55
0 -1 -1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
6-
2 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
6+
2 0 0 0 -1 0 -1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
77
0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0
88
0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
99
0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
@@ -12,17 +12,17 @@
1212
0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0
1313
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1414
0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
15-
0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 2
15+
0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 -1
1616
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0
1717
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0
1818
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0
1919
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2020
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 -1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0
2121
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0
2222
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0
23-
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0
23+
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 -1 0 0 0 0 0 0 0 0 0 0 0 0 0
2424
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0
25-
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 0 0 0 2 0 2 0 0 0 0 2 -1 0 0 0 0 0
25+
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 0 0 0 -1 0 -1 0 0 0 0 2 -1 0 0 0 0 0
2626
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0
2727
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 -1 0 0 0 0 0 0 0
2828
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0

0 commit comments

Comments
 (0)