Skip to content

Commit 2a01a86

Browse files
authored
Merge pull request #216 from kunwuz/main
Update LocalScoreFunctionClass to fix issue calling local_score_BIC_from_cov
2 parents f6a96e3 + bc9b002 commit 2a01a86

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

causallearn/score/LocalScoreFunctionClass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
self.parameters = parameters
3030
self.score_cache = {}
3131

32-
if self.local_score_fun == local_score_BIC_from_cov:
32+
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
3333
self.cov = np.cov(self.data.T)
3434
self.n = self.data.shape[0]
3535

@@ -40,15 +40,15 @@ def score(self, i: int, PAi: List[int]) -> float:
4040
hash_key = tuple(sorted(PAi))
4141

4242
if not self.score_cache[i].__contains__(hash_key):
43-
if self.local_score_fun == local_score_BIC_from_cov:
43+
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
4444
self.score_cache[i][hash_key] = self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
4545
else:
4646
self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters)
4747

4848
return self.score_cache[i][hash_key]
4949

5050
def score_nocache(self, i: int, PAi: List[int]) -> float:
51-
if self.local_score_fun == local_score_BIC_from_cov:
51+
if self.local_score_fun.__name__ == 'local_score_BIC_from_cov':
5252
return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
5353
else:
54-
return self.local_score_fun(self.data, i, PAi, self.parameters)
54+
return self.local_score_fun(self.data, i, PAi, self.parameters)

causallearn/search/PermutationBased/BOSS.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
def boss(
2525
X: np.ndarray,
26-
score_func: str = "local_score_BIC",
26+
score_func: str = "local_score_BIC_from_cov",
2727
parameters: Optional[Dict[str, Any]] = None,
2828
verbose: Optional[bool] = True,
2929
node_names: Optional[List[str]] = None,

0 commit comments

Comments
 (0)