Skip to content

Commit ee86e69

Browse files
authored
feat: support model deployment in parallel (#166)
* feat: support model deployment in parallel * feat: update status for parallel execution handling * fix: update typer version * fix: code cleanup
1 parent 494d831 commit ee86e69

File tree

8 files changed

+782
-191
lines changed

8 files changed

+782
-191
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ exclude = [".venv"]
1919

2020
[tool.poetry.dependencies]
2121
python = "^3.9"
22-
typer = {extras = ["all"], version = "0.15.3"}
22+
typer = "0.15.3"
2323
click = "8.0.4"
2424
rich = "14.0.0"
2525
boto3 = "^1.35.0"

src/emd/cfn/codepipeline/template.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ Resources:
276276
- Name: CreateTime
277277
- Name: Region
278278
PipelineType: V2
279-
ExecutionMode: QUEUED
279+
ExecutionMode: PARALLEL
280280
Stages:
281281
- Name: Source
282282
Actions:

src/emd/commands/destroy.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,42 @@
1414
layout = make_layout()
1515

1616

17-
#@app.callback(invoke_without_command=True)(invoke_without_command=True)
1817
@app.callback(invoke_without_command=True)
@catch_aws_credential_errors
@check_emd_env_exist
@load_aws_profile
def destroy(
    model_identifier: Annotated[
        str,
        typer.Argument(
            help="Model identifier in format 'model_id/model_tag' (e.g., 'Qwen2.5-0.5B-Instruct/d2')"
        ),
    ]
):
    """
    Destroy a model deployment.

    Examples:
        emd destroy Qwen2.5-0.5B-Instruct/d2
        emd destroy Qwen2.5-VL-32B-Instruct/twopath
        emd destroy DeepSeek-R1-0528-Qwen3-8B/dev
    """
    # CLI entry point: hands the raw identifier to the SDK, which parses it
    # and blocks until the deletion has finished. Any failure is reported on
    # the console and turned into exit code 1.
    try:
        console.print(f"[yellow]Destroying model deployment: {model_identifier}[/yellow]")

        # Use the new SDK format
        sdk_destroy(model_identifier=model_identifier, waiting_until_complete=True)

        success_lines = (
            f"[green]✅ Model deployment '{model_identifier}' has been successfully deleted[/green]",
            "[dim]The model stack and all associated resources have been removed[/dim]",
        )
        for line in success_lines:
            console.print(line)

    except ValueError as e:
        # The SDK raises ValueError for a malformed identifier; show the
        # expected shape together with a couple of concrete invocations.
        usage_hint_lines = (
            f"[red]❌ Invalid format: {e}[/red]",
            "[yellow]Expected format: 'model_id/model_tag'[/yellow]",
            "[yellow]Examples:[/yellow]",
            " [cyan]emd destroy Qwen2.5-0.5B-Instruct/d2[/cyan]",
            " [cyan]emd destroy Qwen2.5-VL-32B-Instruct/twopath[/cyan]",
        )
        for line in usage_hint_lines:
            console.print(line)
        raise typer.Exit(1)
    except Exception as e:
        # Top-level CLI boundary: report the failure and exit non-zero.
        console.print(f"[red]❌ Failed to destroy model deployment: {e}[/red]")
        raise typer.Exit(1)

src/emd/commands/status.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from emd.utils.logger_utils import make_layout
99
from rich.console import Console
1010
from rich.table import Table
11+
from rich.spinner import Spinner
12+
from rich.live import Live
1113

1214
app = typer.Typer(pretty_exceptions_enable=False)
1315
console = Console()
@@ -26,24 +28,29 @@ def status(
2628
str, typer.Argument(help="Model tag")
2729
] = MODEL_DEFAULT_TAG,
2830
):
29-
ret = get_model_status(model_id, model_tag=model_tag)
31+
# Show loading indicator while fetching model status
32+
with console.status("[bold green]Fetching model status...", spinner="dots"):
33+
ret = get_model_status(model_id, model_tag=model_tag)
34+
3035
inprogress = ret['inprogress']
3136
completed = ret['completed']
3237

3338
data = []
39+
# Process all in-progress executions (now includes ALL parallel executions)
3440
for d in inprogress:
3541
if d['status'] == "Stopped":
3642
continue
3743
data.append({
3844
"model_id": d['model_id'],
3945
"model_tag": d['model_tag'],
40-
"status": f"{d['status']} ({d['stage_name']})",
41-
"service_type": d['service_type'],
42-
"instance_type": d['instance_type'],
43-
"create_time": d['create_time'],
44-
"outputs": d['outputs'],
46+
"status": f"{d['status']} ({d['stage_name']})" if d.get('stage_name') else d['status'],
47+
"service_type": d.get('service_type', ''),
48+
"instance_type": d.get('instance_type', ''),
49+
"create_time": d.get('create_time', ''),
50+
"outputs": d.get('outputs', ''), # Use .get() to handle missing outputs field
4551
})
4652

53+
# Process completed models
4754
for d in completed:
4855
data.append({
4956
"model_id": d['model_id'],
@@ -79,16 +86,14 @@ def status(
7986
# Display the Models section
8087
console.print("\nModels", style="bold")
8188

82-
# Create a custom box style without vertical lines
83-
8489
# Create a single table for all models with normal horizontal lines but no vertical lines
8590
models_table = Table(show_header=False, expand=True)
8691

8792
# Add two columns for name/value pairs
8893
models_table.add_column(justify="left", style="cyan", width=22)
8994
models_table.add_column(justify="left", overflow="fold")
9095

91-
# Add each model to the table
96+
# Add each model to the table (now shows ALL parallel executions)
9297
for model_data in data:
9398
# Add model name as a name/value pair with bold styling
9499
model_name = f"{model_data['model_id']}/{model_data['model_tag']}"

src/emd/sdk/destroy.py

Lines changed: 150 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import boto3
22
import time
3+
from typing import Union, Tuple
34
from emd.utils.logger_utils import get_logger
45
from .status import get_destroy_status
56
from emd.constants import (
@@ -18,76 +19,207 @@
1819
from emd.models.utils.constants import ServiceType
1920
from emd.models import Model
2021
from emd.utils.aws_service_utils import get_current_region
22+
2123
logger = get_logger(__name__)
2224

2325

26+
def parse_model_identifier(model_identifier: str) -> Tuple[str, str]:
    """
    Parse a model identifier of the form 'model_id/model_tag'.

    Whitespace around the whole identifier and around each part is ignored.
    An identifier without a '/' is treated as a bare model ID and paired
    with MODEL_DEFAULT_TAG (backward compatibility with the old CLI form).

    Args:
        model_identifier: String in format 'model_id/model_tag' or just 'model_id'

    Returns:
        Tuple of (model_id, model_tag)

    Raises:
        ValueError: If the identifier is not a string, is empty or blank,
            contains more than one '/', or has an empty model ID or tag

    Examples:
        parse_model_identifier('Qwen2.5-0.5B-Instruct/d2') -> ('Qwen2.5-0.5B-Instruct', 'd2')
        parse_model_identifier('Qwen2.5-0.5B-Instruct') -> ('Qwen2.5-0.5B-Instruct', MODEL_DEFAULT_TAG)
    """
    if model_identifier is None:
        raise ValueError("Model identifier cannot be empty")
    if not isinstance(model_identifier, str):
        # Keep the documented contract: callers only catch ValueError, so a
        # wrong-typed argument must not leak out as AttributeError.
        raise ValueError(
            f"Model identifier must be a string, got {type(model_identifier).__name__}"
        )

    identifier = model_identifier.strip()
    if not identifier:
        raise ValueError("Model identifier cannot be empty")

    if '/' not in identifier:
        # Backward compatibility: treat as model_id with default tag
        return identifier, MODEL_DEFAULT_TAG

    # Strip each part once here so the emptiness checks and the returned
    # values are guaranteed to agree.
    parts = [part.strip() for part in identifier.split('/')]
    if len(parts) != 2:
        raise ValueError(
            f"Invalid format: '{identifier}'. "
            f"Expected format: 'model_id/model_tag' (e.g., 'Qwen2.5-0.5B-Instruct/d2')"
        )

    model_id, model_tag = parts
    if not model_id:
        raise ValueError("Model ID cannot be empty")
    if not model_tag:
        raise ValueError("Model tag cannot be empty")

    return model_id, model_tag
66+
67+
2468
def stop_pipeline_execution(
    model_id: str,
    model_tag: str,
    pipeline_name: str = CODEPIPELINE_NAME,
    waiting_until_complete: bool = True
):
    """
    Stop every active pipeline execution that belongs to a model.

    Active executions are matched by the stack-name prefix derived from
    (model_id, model_tag). With the pipeline running in PARALLEL mode more
    than one execution can be in flight for the same model/tag, so all
    matches are stopped, not just one of them.

    Args:
        model_id: Model ID
        model_tag: Model tag
        pipeline_name: Name of the CodePipeline
        waiting_until_complete: Whether to poll (every 5 seconds, with no
            timeout) until each stopped execution reports a final status

    Raises:
        Exception: Any error from the stop request other than a duplicated
            stop request is logged and re-raised.
    """
    logger.info(f"Checking for active pipeline executions for model: {model_id}, tag: {model_tag}")

    active_executions = get_pipeline_active_executions(
        pipeline_name=pipeline_name
    )

    # Group instead of building a flat dict: a flat dict keeps only the last
    # execution per key and would silently leave the others running.
    executions_by_key = {}
    for info in active_executions:
        key = Model.get_model_stack_name_prefix(info['model_id'], info['model_tag'])
        executions_by_key.setdefault(key, []).append(info)

    target_key = Model.get_model_stack_name_prefix(model_id, model_tag)
    logger.info(f"Looking for pipeline execution with key: {target_key}")

    matching_executions = executions_by_key.get(target_key, [])
    if not matching_executions:
        logger.warning(f"No active pipeline execution found for model: {model_id}, tag: {model_tag}")
        logger.info(f"Available active executions: {list(executions_by_key.keys())}")
        return

    client = boto3.client('codepipeline', region_name=get_current_region())
    execution_ids = []
    for info in matching_executions:
        pipeline_execution_id = info['pipeline_execution_id']
        execution_ids.append(pipeline_execution_id)

        logger.info(f"Found active pipeline execution: {pipeline_execution_id}")
        logger.info(f"Current status: {info.get('status', 'Unknown')}")

        try:
            client.stop_pipeline_execution(
                pipelineName=pipeline_name,
                pipelineExecutionId=pipeline_execution_id
            )
            logger.info(f"Stop request sent for pipeline execution: {pipeline_execution_id}")
        except client.exceptions.DuplicatedStopRequestException as e:
            # A stop is already in progress for this execution; nothing to do.
            logger.warning(f"Stop request already sent for execution {pipeline_execution_id}: {e}")
        except Exception as e:
            logger.error(f"Failed to stop pipeline execution {pipeline_execution_id}: {e}")
            raise

    if not waiting_until_complete:
        return

    logger.info("Waiting for pipeline execution to stop...")
    for pipeline_execution_id in execution_ids:
        while True:
            current_status = get_pipeline_execution_info(
                pipeline_name=pipeline_name,
                pipeline_execution_id=pipeline_execution_id,
            )['status']
            logger.info(f"Pipeline execution status: {current_status}")

            if current_status == 'Stopped':
                logger.info("Pipeline execution stopped successfully")
                break
            if current_status in ('Succeeded', 'Failed', 'Cancelled'):
                # The execution finished on its own before the stop took effect.
                logger.info(f"Pipeline execution completed with status: {current_status}")
                break

            time.sleep(5)
59137

60138

61139
def destroy_ecs(model_id,model_tag,stack_name):
    """
    Request deletion of the CloudFormation stack behind an ECS deployment.

    Only stack_name is used; model_id and model_tag are accepted but not
    read here. The call returns as soon as the delete request has been
    issued and does not wait for the deletion to finish.
    """
    region = get_current_region()
    boto3.client('cloudformation', region_name=region).delete_stack(StackName=stack_name)
64142

65-
def destroy(
    model_id: Union[str, None] = None,
    model_tag: str = MODEL_DEFAULT_TAG,
    model_identifier: Union[str, None] = None,
    waiting_until_complete: bool = True
):
    """
    Destroy a model deployment.

    If the model's stack does not exist, any active pipeline execution for
    the model is stopped instead and the function returns. If the stack's
    ServiceType is ECS, deletion is delegated to destroy_ecs, which returns
    without waiting. Otherwise the CloudFormation stack is deleted and,
    when requested, polled every 5 seconds until get_destroy_status
    reports status_code 0.

    Args:
        model_id: Model ID (legacy format)
        model_tag: Model tag (legacy format)
        model_identifier: Model identifier in 'model_id/model_tag' format (new format)
        waiting_until_complete: Whether to wait for deletion to complete

    Examples:
        # New format (recommended)
        destroy(model_identifier='Qwen2.5-0.5B-Instruct/d2')

        # Legacy format (still supported)
        destroy(model_id='Qwen2.5-0.5B-Instruct', model_tag='d2')

    Raises:
        ValueError: If neither format is provided or format is invalid
    """
    # NOTE(review): model_identifier sits before waiting_until_complete, so a
    # pre-existing caller passing waiting_until_complete as the third
    # positional argument now binds it to model_identifier. Confirm no such
    # caller exists, or pass these two by keyword.

    # Validate the arguments before touching AWS so that a malformed call
    # fails fast and does not depend on the environment being reachable.
    if model_identifier is not None:
        if model_id is not None:
            raise ValueError("Cannot specify both model_identifier and model_id. Use either the new format (model_identifier='model_id/model_tag') or legacy format (model_id='model_id', model_tag='model_tag')")

        # Parse new format
        try:
            model_id, model_tag = parse_model_identifier(model_identifier)
        except ValueError as e:
            logger.error(f"Invalid model identifier format: {e}")
            raise
        logger.info(f"Parsed model identifier '{model_identifier}' -> model_id='{model_id}', model_tag='{model_tag}'")
    elif model_id is not None:
        # Legacy format
        logger.info(f"Using legacy format -> model_id='{model_id}', model_tag='{model_tag}'")
    else:
        raise ValueError("Must specify either model_identifier (new format) or model_id (legacy format)")

    check_env_stack_exist_and_complete()

    stack_name = Model.get_model_stack_name_prefix(model_id, model_tag=model_tag)
    logger.info(f"Target stack name: {stack_name}")

    if not check_stack_exists(stack_name):
        # No stack yet: the deployment may still be running in the pipeline,
        # so stop that execution instead of deleting anything.
        logger.info(f"Stack {stack_name} does not exist, checking for active pipeline executions...")
        stop_pipeline_execution(model_id, model_tag, waiting_until_complete=waiting_until_complete)
        return

    stack_info = get_stack_info(stack_name)
    parameters = stack_info['parameters']
    if parameters['ServiceType'] == ServiceType.ECS:
        logger.info(f"Destroying ECS service for stack: {stack_name}")
        return destroy_ecs(model_id, model_tag, stack_name)

    cf_client = boto3.client('cloudformation', region_name=get_current_region())
    cf_client.delete_stack(StackName=stack_name)

    logger.info(f"CloudFormation stack deletion started: {stack_name}")
    logger.info("Deleting model infrastructure (compute instances, load balancers, security groups, etc.)")

    if not waiting_until_complete:
        return

    # check delete status
    logger.info("Waiting for stack deletion to complete...")
    while True:
        status_info = get_destroy_status(stack_name)
        status = status_info['status']
        if status_info['status_code'] == 0:
            break
        logger.info(f'Stack deletion progress: {status}')
        time.sleep(5)

    if status == EMD_STACK_NOT_EXISTS_STATUS:
        logger.info("✅ Model deployment successfully deleted - all resources have been removed")
    else:
        # Polling ended but the stack still exists; presumably the deletion
        # did not finish cleanly, so surface it above info level.
        logger.warning(f'Stack deletion completed with status: {status}')

0 commit comments

Comments
 (0)