2
2
3
3
import warnings
4
4
from queue import Queue
5
- from typing import List , Set , Tuple , Dict
5
+ from typing import List , Set , Tuple , Dict , Generator
6
6
from numpy import ndarray
7
7
8
8
from causallearn .graph .Edge import Edge
17
17
from causallearn .utils .PCUtils .BackgroundKnowledge import BackgroundKnowledge
18
18
from itertools import combinations
19
19
20
+ def is_uncovered_path (nodes : List [Node ], G : Graph ) -> bool :
21
+ """
22
+ Determines whether the given path is an uncovered path in this graph.
23
+
24
+ A path is an uncovered path if no two nonconsecutive nodes (Vi-1 and Vi+1) in the path are
25
+ adjacent.
26
+ """
27
+ for i in range (len (nodes ) - 2 ):
28
+ if G .is_adjacent_to (nodes [i ], nodes [i + 2 ]):
29
+ return False
30
+ return True
31
+
32
+
20
33
def traverseSemiDirected (node : Node , edge : Edge ) -> Node | None :
21
34
if node == edge .get_node1 ():
22
35
if edge .get_endpoint1 () == Endpoint .TAIL or edge .get_endpoint1 () == Endpoint .CIRCLE :
@@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
26
39
return edge .get_node1 ()
27
40
return None
28
41
42
+ def traverseCircle (node : Node , edge : Edge ) -> Node | None :
43
+ if node == edge .get_node1 ():
44
+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
45
+ return edge .get_node2 ()
46
+ elif node == edge .get_node2 ():
47
+ if edge .get_endpoint1 () == Endpoint .CIRCLE and edge .get_endpoint2 () == Endpoint .CIRCLE :
48
+ return edge .get_node1 ()
49
+ return None
50
+
29
51
30
- def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool :
52
+ def existsSemiDirectedPath (node_from : Node , node_to : Node , G : Graph ) -> bool : ## TODO: Now it does not detect whether the path is an uncovered path
31
53
Q = Queue ()
32
54
V = set ()
33
55
@@ -60,6 +82,42 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
60
82
61
83
return False
62
84
85
+ def GetUncoveredCirclePath (node_from : Node , node_to : Node , G : Graph , exclude_node : List [Node ]) -> Generator [Node ] | None :
86
+ Q = Queue ()
87
+ V = set ()
88
+
89
+ path = [node_from ]
90
+
91
+ for node_u in G .get_adjacent_nodes (node_from ):
92
+ if node_u in exclude_node :
93
+ continue
94
+ edge = G .get_edge (node_from , node_u )
95
+ node_c = traverseCircle (node_from , edge )
96
+
97
+ if node_c is None or node_c in exclude_node :
98
+ continue
99
+
100
+ if not V .__contains__ (node_c ):
101
+ V .add (node_c )
102
+ Q .put ((node_c , path + [node_c ]))
103
+
104
+ while not Q .empty ():
105
+ node_t , path = Q .get_nowait ()
106
+ if node_t == node_to and is_uncovered_path (path , G ):
107
+ yield path
108
+
109
+ for node_u in G .get_adjacent_nodes (node_t ):
110
+ edge = G .get_edge (node_t , node_u )
111
+ node_c = traverseCircle (node_t , edge )
112
+
113
+ if node_c is None or node_c in exclude_node :
114
+ continue
115
+
116
+ if not V .__contains__ (node_c ):
117
+ V .add (node_c )
118
+ Q .put ((node_c , path + [node_c ]))
119
+
120
+
63
121
64
122
def existOnePathWithPossibleParents (previous , node_w : Node , node_x : Node , node_b : Node , graph : Graph ) -> bool :
65
123
if node_w == node_x :
@@ -371,6 +429,131 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
371
429
changeFlag = True
372
430
return changeFlag
373
431
432
+ def ruleR5 (graph : Graph , changeFlag : bool ,
433
+ verbose : bool = False ) -> bool :
434
+ """
435
+ Rule R5 of the FCI algorithm.
436
+ by Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias"]
437
+
438
+ This function orients any edge that is part of an uncovered circle path between two nodes A and B,
439
+ if such a path exists. The path must start and end with a circle edge and must be uncovered, i.e. the
440
+ nodes on the path must not be adjacent to A or B. The orientation of the edges on the path is set to
441
+ double tail.
442
+ """
443
+ nodes = graph .get_nodes ()
444
+ def orient_on_path_helper (path , node_A , node_B ):
445
+ # orient A - C, D - B
446
+ edge = graph .get_edge (node_A , path [0 ])
447
+ graph .remove_edge (edge )
448
+ graph .add_edge (Edge (node_A , path [0 ], Endpoint .TAIL , Endpoint .TAIL ))
449
+
450
+ edge = graph .get_edge (node_B , path [- 1 ])
451
+ graph .remove_edge (edge )
452
+ graph .add_edge (Edge (node_B , path [- 1 ], Endpoint .TAIL , Endpoint .TAIL ))
453
+ if verbose :
454
+ print ("Orienting edge A - C (Double tail): " + graph .get_edge (node_A , path [0 ]).__str__ ())
455
+ print ("Orienting edge B - D (Double tail): " + graph .get_edge (node_B , path [- 1 ]).__str__ ())
456
+
457
+ # orient everything on the path to both tails
458
+ for i in range (len (path ) - 1 ):
459
+ edge = graph .get_edge (path [i ], path [i + 1 ])
460
+ graph .remove_edge (edge )
461
+ graph .add_edge (Edge (path [i ], path [i + 1 ], Endpoint .TAIL , Endpoint .TAIL ))
462
+ if verbose :
463
+ print ("Orienting edge (Double tail): " + graph .get_edge (path [i ], path [i + 1 ]).__str__ ())
464
+
465
+ for node_B in nodes :
466
+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
467
+
468
+ for node_A in intoBCircles :
469
+ found_paths_between_AB = []
470
+ if graph .get_endpoint (node_B , node_A ) != Endpoint .CIRCLE :
471
+ continue
472
+ else :
473
+ # Check if there is an uncovered circle path between A and B (A o-o C .. D o-o B)
474
+ # s.t. A is not adjacent to D and B is not adjacent to C
475
+ a_node_idx = graph .node_map [node_A ]
476
+ b_node_idx = graph .node_map [node_B ]
477
+ a_adj_nodes = graph .get_adjacent_nodes (node_A )
478
+ b_adj_nodes = graph .get_adjacent_nodes (node_B )
479
+
480
+ # get the adjacent nodes with circle edges of A and B
481
+ a_circle_adj_nodes_set = [node for node in a_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
482
+ and graph .get_endpoint (node , node_A ) == Endpoint .CIRCLE and graph .get_endpoint (node_A , node ) == Endpoint .CIRCLE ]
483
+ b_circle_adj_nodes_set = [node for node in b_adj_nodes if graph .node_map [node ] != a_node_idx and graph .node_map [node ]!= b_node_idx
484
+ and graph .get_endpoint (node , node_B ) == Endpoint .CIRCLE and graph .get_endpoint (node_B , node ) == Endpoint .CIRCLE ]
485
+
486
+ # get the adjacent nodes with circle edges of A and B that is non adjacent to B and A, respectively
487
+ for node_C in a_circle_adj_nodes_set :
488
+ if graph .is_adjacent_to (node_B , node_C ):
489
+ continue
490
+ for node_D in b_circle_adj_nodes_set :
491
+ if graph .is_adjacent_to (node_A , node_D ):
492
+ continue
493
+ paths = GetUncoveredCirclePath (node_from = node_C , node_to = node_D , G = graph , exclude_node = [node_A , node_B ]) # get the uncovered circle path between C and D, excluding A and B
494
+ found_paths_between_AB .append (paths )
495
+
496
+ # Orient the uncovered circle path between A and B
497
+ for paths in found_paths_between_AB :
498
+ for path in paths :
499
+ changeFlag = True
500
+ if verbose :
501
+ print ("Find uncovered circle path between A and B: " + graph .get_edge (node_A , node_B ).__str__ ())
502
+ edge = graph .get_edge (node_A , node_B )
503
+ graph .remove_edge (edge )
504
+ graph .add_edge (Edge (node_A , node_B , Endpoint .TAIL , Endpoint .TAIL ))
505
+ orient_on_path_helper (path , node_A , node_B )
506
+
507
+ return changeFlag
508
+
509
+ def ruleR6 (graph : Graph , changeFlag : bool ,
510
+ verbose : bool = False ) -> bool :
511
+ nodes = graph .get_nodes ()
512
+
513
+ for node_B in nodes :
514
+ # Find A - B
515
+ intoBTails = graph .get_nodes_into (node_B , Endpoint .TAIL )
516
+ exist = False
517
+ for node_A in intoBTails :
518
+ if graph .get_endpoint (node_B , node_A ) == Endpoint .TAIL :
519
+ exist = True
520
+ if not exist :
521
+ continue
522
+ # Find B o-*C
523
+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
524
+ for node_C in intoBCircles :
525
+ changeFlag = True
526
+ edge = graph .get_edge (node_B , node_C )
527
+ graph .remove_edge (edge )
528
+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
529
+ if verbose :
530
+ print ("Orienting edge by rule 6): " + graph .get_edge (node_B , node_C ).__str__ ())
531
+
532
+ return changeFlag
533
+
534
+
535
+ def ruleR7 (graph : Graph , changeFlag : bool ,
536
+ verbose : bool = False ) -> bool :
537
+ nodes = graph .get_nodes ()
538
+
539
+ for node_B in nodes :
540
+ # Find A -o B
541
+ intoBCircles = graph .get_nodes_into (node_B , Endpoint .CIRCLE )
542
+ node_A_list = [node for node in intoBCircles if graph .get_endpoint (node_B , node ) == Endpoint .TAIL ]
543
+
544
+ # Find B o-*C
545
+ for node_C in intoBCircles :
546
+ # pdb.set_trace()
547
+ for node_A in node_A_list :
548
+ # pdb.set_trace()
549
+ if not graph .is_adjacent_to (node_A , node_C ):
550
+ changeFlag = True
551
+ edge = graph .get_edge (node_B , node_C )
552
+ graph .remove_edge (edge )
553
+ graph .add_edge (Edge (node_B , node_C , Endpoint .TAIL , edge .get_proximal_endpoint (node_C )))
554
+ if verbose :
555
+ print ("Orienting edge by rule 7): " + graph .get_edge (node_B , node_C ).__str__ ())
556
+ return changeFlag
374
557
375
558
def getPath (node_c : Node , previous ) -> List [Node ]:
376
559
l = []
@@ -544,9 +727,8 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
544
727
545
728
546
729
547
- def rule8 (graph : Graph , nodes : List [Node ]):
548
- nodes = graph .get_nodes ()
549
- changeFlag = False
730
+ def rule8 (graph : Graph , nodes : List [Node ], changeFlag ):
731
+ nodes = graph .get_nodes () if nodes is None else nodes
550
732
for node_B in nodes :
551
733
adj = graph .get_adjacent_nodes (node_B )
552
734
if len (adj ) < 2 :
@@ -601,9 +783,9 @@ def find_possible_children(graph: Graph, parent_node, en_nodes=None):
601
783
602
784
return potential_child_nodes
603
785
604
- def rule9 (graph : Graph , nodes : List [Node ]):
605
- changeFlag = False
606
- nodes = graph .get_nodes ()
786
+ def rule9 (graph : Graph , nodes : List [Node ], changeFlag ):
787
+ # changeFlag = False
788
+ nodes = graph .get_nodes () if nodes is None else nodes
607
789
for node_C in nodes :
608
790
intoCArrows = graph .get_nodes_into (node_C , Endpoint .ARROW )
609
791
for node_A in intoCArrows :
@@ -629,8 +811,8 @@ def rule9(graph: Graph, nodes: List[Node]):
629
811
return changeFlag
630
812
631
813
632
- def rule10 (graph : Graph ):
633
- changeFlag = False
814
+ def rule10 (graph : Graph , changeFlag ):
815
+ # changeFlag = False
634
816
nodes = graph .get_nodes ()
635
817
for node_C in nodes :
636
818
intoCArrows = graph .get_nodes_into (node_C , Endpoint .ARROW )
@@ -895,6 +1077,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
895
1077
graph , sep_sets , test_results = fas (dataset , nodes , independence_test_method = independence_test_method , alpha = alpha ,
896
1078
knowledge = background_knowledge , depth = depth , verbose = verbose , show_progress = show_progress )
897
1079
1080
+ # pdb.set_trace()
898
1081
reorientAllWith (graph , Endpoint .CIRCLE )
899
1082
900
1083
rule0 (graph , nodes , sep_sets , background_knowledge , verbose )
@@ -925,12 +1108,22 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
925
1108
if verbose :
926
1109
print ("Epoch" )
927
1110
1111
+ # rule 5
1112
+ change_flag = ruleR5 (graph , change_flag , verbose )
1113
+
1114
+ # rule 6
1115
+ change_flag = ruleR6 (graph , change_flag , verbose )
1116
+
1117
+ # rule 7
1118
+ change_flag = ruleR7 (graph , change_flag , verbose )
1119
+
928
1120
# rule 8
929
- change_flag = rule8 (graph ,nodes )
1121
+ change_flag = rule8 (graph ,nodes , change_flag )
1122
+
930
1123
# rule 9
931
- change_flag = rule9 (graph , nodes )
1124
+ change_flag = rule9 (graph , nodes , change_flag )
932
1125
# rule 10
933
- change_flag = rule10 (graph )
1126
+ change_flag = rule10 (graph , change_flag )
934
1127
935
1128
graph .set_pag (True )
936
1129
0 commit comments