|
39 | 39 |
|
40 | 40 | T = TypeVar("T")
|
41 | 41 |
|
| 42 | +# [N,2] -> every line is [config_name, enable_xxx_name] |
| 43 | +# Make sure enable_xxx equal to config.enable_xxx |
| 44 | +ARGS_CORRECTION_LIST = [["early_stop_config", "enable_early_stop"], ["graph_optimization_config", "use_cudagraph"]] |
| 45 | + |
42 | 46 |
|
43 | 47 | class EngineError(Exception):
|
44 | 48 | """Base exception class for engine errors"""
|
@@ -361,8 +365,16 @@ def parse_args(self, args=None, namespace=None):
|
361 | 365 | namespace = argparse.Namespace()
|
362 | 366 | for key, value in filtered_config.items():
|
363 | 367 | setattr(namespace, key, value)
|
364 |
| - |
365 |
| - return super().parse_args(args=remaining_args, namespace=namespace) |
| 368 | + args = super().parse_args(args=remaining_args, namespace=namespace) |
| 369 | + |
| 370 | + # Args correction |
| 371 | + for config_name, flag_name in ARGS_CORRECTION_LIST: |
| 372 | + if hasattr(args, config_name) and hasattr(args, flag_name): |
| 373 | + # config is a dict |
| 374 | + config = getattr(args, config_name, None) |
| 375 | + if config is not None and flag_name in config.keys(): |
| 376 | + setattr(args, flag_name, config[flag_name]) |
| 377 | + return args |
366 | 378 |
|
367 | 379 |
|
368 | 380 | def resolve_obj_from_strname(strname: str):
|
|
0 commit comments