@@ -29,7 +29,7 @@ def __init__(
29
29
self .parameters = parameters
30
30
self .score_cache = {}
31
31
32
- if self .local_score_fun == local_score_BIC_from_cov :
32
+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
33
33
self .cov = np .cov (self .data .T )
34
34
self .n = self .data .shape [0 ]
35
35
@@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
40
40
hash_key = tuple (sorted (PAi ))
41
41
42
42
if not self .score_cache [i ].__contains__ (hash_key ):
43
- if self .local_score_fun == local_score_BIC_from_cov :
43
+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
44
44
self .score_cache [i ][hash_key ] = self .local_score_fun ((self .cov , self .n ), i , PAi , self .parameters )
45
45
else :
46
46
self .score_cache [i ][hash_key ] = self .local_score_fun (self .data , i , PAi , self .parameters )
47
47
48
48
return self .score_cache [i ][hash_key ]
49
49
50
50
def score_nocache (self , i : int , PAi : List [int ]) -> float :
51
- if self .local_score_fun == local_score_BIC_from_cov :
51
+ if self .local_score_fun . __name__ == ' local_score_BIC_from_cov' :
52
52
return self .local_score_fun ((self .cov , self .n ), i , PAi , self .parameters )
53
53
else :
54
- return self .local_score_fun (self .data , i , PAi , self .parameters )
54
+ return self .local_score_fun (self .data , i , PAi , self .parameters )
0 commit comments