From 778843190410e5f654d3ede0148c425890dd636a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jul 2024 15:25:33 -0400 Subject: [PATCH 1/3] fix: fix PyTorch model extension in simplify Signed-off-by: Jinzhe Zeng --- dpgen/simplify/simplify.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 24205fda3..30b3472ac 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -221,7 +221,9 @@ def run_model_devi(iter_index, jdata, mdata): commands = [] run_tasks = ["."] # get models - models = glob.glob(os.path.join(work_path, "graph*pb")) + suffix = _get_model_suffix(jdata) + models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) + assert len(models) > 0, "No model file found." model_names = [os.path.basename(ii) for ii in models] task_model_list = [] for ii in model_names: From 4e09b3d8fc36c2b8505e18794ca13dd63944d396 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jul 2024 15:43:04 -0400 Subject: [PATCH 2/3] touch models in the tests Signed-off-by: Jinzhe Zeng --- tests/simplify/test_run_model_devi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/simplify/test_run_model_devi.py b/tests/simplify/test_run_model_devi.py index e928afa8e..c1ea798d4 100644 --- a/tests/simplify/test_run_model_devi.py +++ b/tests/simplify/test_run_model_devi.py @@ -17,6 +17,9 @@ class TestOneH5(unittest.TestCase): def setUp(self): work_path = Path("iter.000000") / "01.model_devi" work_path.mkdir(parents=True, exist_ok=True) + # fake models + for ii in range(4): + (work_path / f"graph.{ii:03d}.pb").torch() with tempfile.TemporaryDirectory() as tmpdir: with open(Path(tmpdir) / "test.xyz", "w") as f: f.write( From d22f75419951523e57922c5c6ec449c5769994b7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jul 2024 15:51:31 -0400 Subject: [PATCH 3/3] Update test_run_model_devi.py Signed-off-by: Jinzhe Zeng --- tests/simplify/test_run_model_devi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/simplify/test_run_model_devi.py b/tests/simplify/test_run_model_devi.py index c1ea798d4..28d5732e5 100644 --- a/tests/simplify/test_run_model_devi.py +++ b/tests/simplify/test_run_model_devi.py @@ -19,7 +19,7 @@ def setUp(self): work_path.mkdir(parents=True, exist_ok=True) # fake models for ii in range(4): - (work_path / f"graph.{ii:03d}.pb").torch() + (work_path / f"graph.{ii:03d}.pb").touch() with tempfile.TemporaryDirectory() as tmpdir: with open(Path(tmpdir) / "test.xyz", "w") as f: f.write(