Skip to content

Commit 2b69ec8

Browse files
committed
Add smoking test
1 parent d6fefbc commit 2b69ec8

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

tests/why_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ def _validate_it(why, test_data, check_score=True):
4242
print("rloss:", score)
4343

4444

45+
def test_smoke():
46+
from ylearn.api.smoke import smoke
47+
smoke()
48+
49+
4550
def test_basis():
4651
data, test_data, outcome, treatment, adjustment, covariate = _dgp.generate_data_x1b_y1()
4752
why = Why()

ylearn/api/smoke.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from ylearn.api import Why
2+
from ylearn.exp_dataset.exp_data import single_binary_treatment
3+
4+
5+
def smoke(estimator='auto'):
6+
print('-' * 20, 'smoke with estimator', estimator, '-' * 20)
7+
8+
train, val, _ = single_binary_treatment()
9+
te = train.pop('TE')
10+
te = val.pop('TE')
11+
adjustment = [c for c in train.columns.tolist() if c.startswith('w')]
12+
covariate = [c for c in train.columns.tolist() if c.startswith('c')]
13+
14+
if estimator == 'grf':
15+
covariate.extend(adjustment)
16+
adjustment = None
17+
18+
why = Why(estimator=estimator)
19+
why.fit(train, outcome='outcome', treatment='treatment', adjustment=adjustment, covariate=covariate)
20+
21+
cate = why.causal_effect(val)
22+
print('CATE:\n', cate)
23+
24+
auuc = why.score(val, scorer='auuc')
25+
print('AUUC', auuc)
26+
27+
28+
if __name__ == '__main__':
29+
from ylearn.utils import logging
30+
31+
logging.set_level('info')
32+
for est in ['slearner', 'tlearner', 'xlearner', 'dr', 'dml', 'tree', 'grf']:
33+
smoke(est)
34+
35+
print('\n<done>')

0 commit comments

Comments
 (0)