Skip to content

Commit 9011d75

Browse files
authored
Feature/SK-940 | CLI command for train and validate (#658)
1 parent fce112a commit 9011d75

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

fedn/cli/run_cmd.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import click
66
import yaml
7-
87
from fedn.common.exceptions import InvalidClientConfig
98
from fedn.common.log_config import logger
109
from fedn.network.clients.client import Client
@@ -44,7 +43,70 @@ def run_cmd(ctx):
4443
""":param ctx:
4544
"""
4645
pass
46+
@run_cmd.command("validate")
47+
@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
48+
@click.option("-i", "--input", required=True, help="Path to input model" )
49+
@click.option("-o", "--output", required=True,help="Path to write the output JSON containing validation metrics")
50+
@click.pass_context
51+
def validate_cmd(ctx, path,input,output):
52+
"""Execute 'validate' entrypoint in fedn.yaml.
53+
54+
:param ctx:
55+
:param path: Path to folder containing fedn.yaml
56+
:type path: str
57+
"""
58+
path = os.path.abspath(path)
59+
yaml_file = os.path.join(path, "fedn.yaml")
60+
if not os.path.exists(yaml_file):
61+
logger.error(f"Could not find fedn.yaml in {path}")
62+
exit(-1)
63+
64+
config = _read_yaml_file(yaml_file)
65+
# Check that validate is defined in fedn.yaml under entry_points
66+
if "validate" not in config["entry_points"]:
67+
logger.error("No validate command defined in fedn.yaml")
68+
exit(-1)
69+
70+
dispatcher = Dispatcher(config, path)
71+
_ = dispatcher._get_or_create_python_env()
72+
dispatcher.run_cmd("validate {} {}".format(input, output))
73+
74+
# delete the virtualenv
75+
if dispatcher.python_env_path:
76+
logger.info(f"Removing virtualenv {dispatcher.python_env_path}")
77+
shutil.rmtree(dispatcher.python_env_path)
78+
@run_cmd.command("train")
79+
@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
80+
@click.option("-i", "--input", required=True, help="Path to input model parameters" )
81+
@click.option("-o", "--output", required=True,help="Path to write the updated model parameters ")
82+
@click.pass_context
83+
def train_cmd(ctx, path,input,output):
84+
"""Execute 'train' entrypoint in fedn.yaml.
4785
86+
:param ctx:
87+
:param path: Path to folder containing fedn.yaml
88+
:type path: str
89+
"""
90+
path = os.path.abspath(path)
91+
yaml_file = os.path.join(path, "fedn.yaml")
92+
if not os.path.exists(yaml_file):
93+
logger.error(f"Could not find fedn.yaml in {path}")
94+
exit(-1)
95+
96+
config = _read_yaml_file(yaml_file)
97+
# Check that train is defined in fedn.yaml under entry_points
98+
if "train" not in config["entry_points"]:
99+
logger.error("No train command defined in fedn.yaml")
100+
exit(-1)
101+
102+
dispatcher = Dispatcher(config, path)
103+
_ = dispatcher._get_or_create_python_env()
104+
dispatcher.run_cmd("train {} {}".format(input, output))
105+
106+
# delete the virtualenv
107+
if dispatcher.python_env_path:
108+
logger.info(f"Removing virtualenv {dispatcher.python_env_path}")
109+
shutil.rmtree(dispatcher.python_env_path)
48110
@run_cmd.command("startup")
49111
@click.option("-p", "--path", required=True, help="Path to package directory containing fedn.yaml")
50112
@click.pass_context

0 commit comments

Comments
 (0)