@@ -32,18 +32,10 @@ def __init__(self,
32
32
self .node_name = node_name
33
33
self .parent_node = parent_node
34
34
self .child_nodes = child_nodes
35
- self .node_data = None
35
+ self .original_child_nodes = child_nodes
36
36
37
- def __call__ (self , persistent = False ):
38
-
39
- if persistent :
40
- if self .node_data is None :
41
- self .node_data = pickle .load (open (self .node_path , "rb" ))
42
-
43
- return self .node_data
44
-
45
- else :
46
- return pickle .load (open (self .node_path , "rb" ))
37
+ def __call__ (self ):
38
+ return pickle .load (open (self .node_path , "rb" ))
47
39
48
40
49
41
class Node ():
@@ -78,6 +70,7 @@ def __init__(self,
78
70
self .parent_node_k = parent_node_k
79
71
self .parent_node_name = parent_node_name
80
72
self .child_node_names = child_node_names
73
+ self .original_child_node_names = child_node_names
81
74
self .original_indices = original_indices
82
75
self .num_samples = num_samples
83
76
self .leaf = leaf
@@ -109,7 +102,8 @@ def __init__(self,
109
102
n_nodes = 1 ,
110
103
verbose = True ,
111
104
comm_buff_size = 10000000 ,
112
- random_identifiers = False
105
+ random_identifiers = False ,
106
+ root_node_name = "Root"
113
107
):
114
108
"""
115
109
HNMFk is a Hierarchical Non-negative Matrix Factorization module with the capability to do automatic model determination.
@@ -154,6 +148,8 @@ def __init__(self,
154
148
If True, it prints progress. The default is True.
155
149
random_identifiers : bool, optional
156
150
If True, model will use randomly generated strings as the identifiers of the nodes. Otherwise, it will use the k for ancestry naming convention.
151
+ root_node_name : str, optional
152
+ Naming convention to be used when saving the root name. Default is "Root".
157
153
Returns
158
154
-------
159
155
None.
@@ -174,6 +170,7 @@ def __init__(self,
174
170
self .verbose = verbose
175
171
self .comm_buff_size = comm_buff_size
176
172
self .random_identifiers = random_identifiers
173
+ self .root_node_name = root_node_name
177
174
178
175
organized_nmfk_params = []
179
176
for params in nmfk_params :
@@ -309,7 +306,7 @@ def fit(self, X, Ks, from_checkpoint=False, save_checkpoint=True):
309
306
if self .random_identifiers :
310
307
self .root_name = str (uuid .uuid1 ())
311
308
else :
312
- self .root_name = "*"
309
+ self .root_name = self . root_node_name
313
310
314
311
self .target_jobs [self .root_name ] = {
315
312
"parent_node_name" :"None" ,
@@ -726,6 +723,110 @@ def traverse_nodes(self):
726
723
727
724
return return_data
728
725
726
+ def traverse_tiny_leaf_topics (self , threshold = 5 ):
727
+ """
728
+ Graph iterator with thresholding on the number of documents. Returns a list of nodes where the number of documents is less than the threshold.\n
729
+ This operation is online, only the nodes that are outliers based on the number of documents are kept in the memory.
730
+
731
+ Parameters
732
+ ----------
733
+ threshold : int
734
+ Minimum number of documents each node should have.
735
+
736
+ Returns
737
+ -------
738
+ data : list
739
+ List of dictionaries in the node format, one for each entry in the list.
740
+
741
+ """
742
+ self ._all_nodes = []
743
+ self ._get_traversal (self .root , small_docs_thresh = threshold )
744
+ return_data = self ._all_nodes .copy ()
745
+ self ._all_nodes = []
746
+
747
+ return return_data
748
+
749
+ def get_tiny_leaf_topics (self ):
750
+ """
751
+ Graph iterator for tiny leaf topics, if already processed with self.process_tiny_leaf_topics(threshold:int).\n
752
+
753
+ Returns
754
+ -------
755
+ tiny_leafs : list
756
+ List of dictionaries in the node format, one for each entry in the list.
757
+
758
+ """
759
+ try :
760
+ return pickle .load (open (os .path .join (self .experiment_name , "tiny_leafs.p" ), "rb" ))
761
+ except Exception as e :
762
+ print ("Could not load the tiny leafs. Did you call process_tiny_leaf_topics(threshold:int)?" , e )
763
+ return None
764
+
765
+ def process_tiny_leaf_topics (self , threshold = 5 ):
766
+ """
767
+ Graph post-processing with thresholding on number of documents.\n
768
+ Returns a list of all tiny nodes, with all the nodes that had number of documents less than the threshold.\n
769
+ Removes these outlier nodes from child-node lists on the original graph from their parents.\n
770
+ Graph is re-set each time this function is called such that original child nodes are re-assigned.\n
771
+ If threshold=None, this function will re-assign the original child indices only, and return None.
772
+
773
+ Parameters
774
+ ----------
775
+ threshold : int
776
+ Minimum number of documents each node should have.
777
+
778
+ Returns
779
+ -------
780
+ tiny_leafs : list
781
+ List of dictionaries in the node format, one for each entry in the list.
782
+
783
+ """
784
+
785
+ # set the old child nodes on each node
786
+ self ._update_child_nodes_traversal (self .root )
787
+
788
+ # remove the old saved tiny leafs
789
+ try :
790
+ os .remove (os .path .join (self .experiment_name , "tiny_leafs.p" ))
791
+ except :
792
+ pass
793
+
794
+ # if threshold is None, we have reverted everything; nothing left to do
795
+ if threshold is None :
796
+ return
797
+
798
+ tiny_leafs = self .traverse_tiny_leaf_topics (threshold = threshold )
799
+ pickle .dump (tiny_leafs , open (os .path .join (self .experiment_name , "tiny_leafs.p" ), "wb" ))
800
+
801
+ # remove tiny leafs from their parents
802
+ for tf in tiny_leafs :
803
+ my_name = tf ["node_name" ]
804
+ parent_name = tf ["parent_node_name" ]
805
+ parent_node = self ._search_traversal (self .root , parent_name )
806
+
807
+ # remove from online iterator
808
+ parent_node .child_nodes = [node for node in parent_node .child_nodes if node .node_name != my_name ]
809
+
810
+ # also need to remove from saved node data
811
+ parent_node_loaded = parent_node ()
812
+ parent_node_loaded .child_node_names = [node_name for node_name in parent_node_loaded .child_node_names if node_name != my_name ]
813
+ pickle .dump (parent_node_loaded , open (os .path .join (self .experiment_name , * parent_node_loaded .node_save_path .split (os .sep )[1 :]), "wb" ))
814
+
815
+ return tiny_leafs
816
+
817
+ def _update_child_nodes_traversal (self , node ):
818
+
819
+ for nn in node .original_child_nodes :
820
+ self ._update_child_nodes_traversal (nn )
821
+
822
+ if node .child_nodes != node .original_child_nodes :
823
+ node .child_nodes = node .original_child_nodes
824
+
825
+ node_loaded = node ()
826
+ if node_loaded .original_child_node_names != node_loaded .child_node_names :
827
+ node_loaded .child_node_names = node_loaded .original_child_node_names
828
+ pickle .dump (node_loaded , open (os .path .join (self .experiment_name , * node_loaded .node_save_path .split (os .sep )[1 :]), "wb" ))
829
+
729
830
def _search_traversal (self , node , name ):
730
831
731
832
# Base case: if the current node matches the target name
@@ -743,12 +844,17 @@ def _search_traversal(self, node, name):
743
844
# If the node is not found in this branch, return None
744
845
return None
745
846
746
- def _get_traversal (self , node ):
847
+ def _get_traversal (self , node , small_docs_thresh = None ):
747
848
748
849
for nn in node .child_nodes :
749
- self ._get_traversal (nn )
850
+ self ._get_traversal (nn , small_docs_thresh = small_docs_thresh )
851
+
852
+ if small_docs_thresh is not None :
853
+ tmp_node_data = vars (node ()).copy ()
854
+ if not (tmp_node_data ["leaf" ] and tmp_node_data ["num_samples" ] < small_docs_thresh ):
855
+ return
750
856
751
- data = vars (node (persistent = True )).copy ()
857
+ data = vars (node ()).copy ()
752
858
data ["node_save_path" ] = os .path .join (self .experiment_name , * data ["node_save_path" ].split (os .sep )[1 :])
753
859
if data ["node_name" ] != self .root_name :
754
860
data ["parent_node_save_path" ] = os .path .join (self .experiment_name , * data ["parent_node_save_path" ].split (os .sep )[1 :])
0 commit comments