diff --git a/causallearn/utils/GraphUtils.py b/causallearn/utils/GraphUtils.py index 9a9497d..34e75ac 100644 --- a/causallearn/utils/GraphUtils.py +++ b/causallearn/utils/GraphUtils.py @@ -536,10 +536,9 @@ def to_pydot(G: Graph, edges: List[Edge] | None = None, labels: List[str] | None pydot_g = pydot.Dot(title, graph_type="digraph", fontsize=18) pydot_g.obj_dict["attributes"]["dpi"] = dpi - nodes = G.get_nodes() + for i, node in enumerate(nodes): node_name = labels[i] if labels is not None else node.get_name() - pydot_g.add_node(pydot.Node(i, label=node.get_name())) if node.get_node_type() == NodeType.LATENT: pydot_g.add_node(pydot.Node(i, label=node_name, shape='square')) else: