Skip to content

Commit aec275d

Browse files
authored
Merge branch 'develop' into mm_structred_output
2 parents ae2d2b0 + 36dc734 commit aec275d

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

fastdeploy/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939

4040
T = TypeVar("T")
4141

42+
# [N,2] -> every line is [config_name, enable_xxx_name]
43+
# Make sure enable_xxx equal to config.enable_xxx
44+
ARGS_CORRECTION_LIST = [["early_stop_config", "enable_early_stop"], ["graph_optimization_config", "use_cudagraph"]]
45+
4246

4347
class EngineError(Exception):
4448
"""Base exception class for engine errors"""
@@ -361,8 +365,16 @@ def parse_args(self, args=None, namespace=None):
361365
namespace = argparse.Namespace()
362366
for key, value in filtered_config.items():
363367
setattr(namespace, key, value)
364-
365-
return super().parse_args(args=remaining_args, namespace=namespace)
368+
args = super().parse_args(args=remaining_args, namespace=namespace)
369+
370+
# Args correction
371+
for config_name, flag_name in ARGS_CORRECTION_LIST:
372+
if hasattr(args, config_name) and hasattr(args, flag_name):
373+
# config is a dict
374+
config = getattr(args, config_name, None)
375+
if config is not None and flag_name in config.keys():
376+
setattr(args, flag_name, config[flag_name])
377+
return args
366378

367379

368380
def resolve_obj_from_strname(strname: str):

0 commit comments

Comments
 (0)