Skip to content

Commit 1244ff7

Browse files
committed
fix test build models
1 parent 27fb2a8 commit 1244ff7

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

dptb/tests/test_build_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def test_build_nnsk_from_scratch():
4646
"overlap": False,
4747
"seed": 3982377700
4848
}
49-
statistics = None
50-
model = build_model(None, model_options, common_options, statistics)
49+
model = build_model(None, model_options, common_options)
5150

5251
assert isinstance(model, NNSK)
5352
assert model.device == "cpu"
@@ -111,9 +110,8 @@ def test_build_model_MIX_from_scratch():
111110
"overlap": False,
112111
"seed": 3982377700
113112
}
114-
statistics = None
115113

116-
model = build_model(None, model_options, common_options, statistics)
114+
model = build_model(None, model_options, common_options)
117115

118116
assert isinstance(model, MIX)
119117
assert model.name == "mix"

dptb/tests/test_trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_fromscratch_noref_noval(self):
4747
jdata = self.jdata
4848
train_datasets = self.train_datasets
4949
model = build_model(None, model_options=jdata["model_options"],
50-
common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
50+
common_options=jdata["common_options"])
5151
trainer = Trainer(
5252
train_options=jdata["train_options"],
5353
common_options=jdata["common_options"],
@@ -73,7 +73,7 @@ def test_fromscratch_ref_noval(self):
7373
reference_datasets = build_dataset(**jdata["data_options"]["reference"], **jdata["common_options"])
7474

7575
model = build_model(None, model_options=jdata["model_options"],
76-
common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
76+
common_options=jdata["common_options"])
7777

7878
trainer = Trainer(
7979
train_options=jdata["train_options"],
@@ -100,7 +100,7 @@ def test_fromscratch_noref_val(self):
100100
validation_datasets = build_dataset(**jdata["data_options"]["validation"], **jdata["common_options"])
101101

102102
model = build_model(None, model_options=jdata["model_options"],
103-
common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
103+
common_options=jdata["common_options"])
104104

105105
trainer = Trainer(
106106
train_options=jdata["train_options"],
@@ -126,7 +126,7 @@ def test_initmodel_noref_nval(self):
126126
checkpoint = f"{rootdir}/test_sktb/output/test_valence/checkpoint/nnsk.best.pth"
127127
run_options.update({"init_model": checkpoint, "restart": None})
128128
model = build_model(checkpoint, model_options=jdata["model_options"],
129-
common_options=jdata["common_options"], statistics=train_datasets.E3statistics())
129+
common_options=jdata["common_options"])
130130
trainer = Trainer(
131131
train_options=jdata["train_options"],
132132
common_options=jdata["common_options"],

0 commit comments

Comments
 (0)