Skip to content

Commit 850be89

Browse files
committed
fix: deploy issue in cn region
1 parent 87df04c commit 850be89

File tree

7 files changed

+186
-50
lines changed

7 files changed

+186
-50
lines changed

src/dmaa/cfn/codepipeline/template.yaml

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@ Resources:
4444
Action:
4545
- ecr-public:*
4646
- sts:GetServiceBearerToken
47+
- cloudformation:DescribeStacks
48+
- cloudformation:CreateStack
49+
- cloudformation:DeleteStack
50+
- ec2:*
51+
- ecs:*
52+
- iam:CreateRole
53+
- iam:AttachRolePolicy
54+
- iam:PassRole
55+
- iam:PutRolePolicy
56+
- iam:DetachRolePolicy
57+
- iam:DeleteRole
58+
- iam:GetRole
59+
- lambda:*
60+
- logs:*
61+
- elasticloadbalancing:*
4762
Resource:
4863
- "*"
4964
ManagedPolicyArns:
@@ -183,8 +198,8 @@ Resources:
183198
- echo Build started on `date`
184199
build:
185200
commands:
186-
- echo "Building Amazon ECR image..."
187-
- |
201+
- |-
202+
echo "Building Amazon ECR image..."
188203
region=$ModelRegion
189204
model_s3_bucket=$ModelBucketName
190205
model_id=$ModelId
@@ -194,25 +209,26 @@ Resources:
194209
service=$ServiceType
195210
instance_type=$InstanceType
196211
extra_params=$ExtraParams
197-
- cd pipeline/
198-
- export PYTHONPATH=.:$PYTHONPATH
199-
- pip install --upgrade pip
200-
- pip install -r requirements.txt
201-
- python pipeline.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params" --skip_deploy
202-
- cd ..
212+
cd pipeline/
213+
export PYTHONPATH=.:$PYTHONPATH
214+
pip install --upgrade pip
215+
pip install -r requirements.txt
216+
python pipeline.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params" --skip_deploy
217+
cd ..
218+
echo pipeline build completed on `date`
203219
204220
post_build:
205221
commands:
206-
- SERVICE_TYPE=$(echo "$ServiceType" | tr '[:upper:]' '[:lower:]')
207-
- cp cfn/$ServiceType/template.yaml template.yaml
208-
- cp pipeline/parameters.json parameters.json
209-
- echo pipeline build completed on `date`, post deployment starting
210-
- if [ -f cfn/$ServiceType/post_build.py ]; then
222+
- |-
223+
SERVICE_TYPE=$(echo "$ServiceType" | tr '[:upper:]' '[:lower:]')
224+
cp cfn/$ServiceType/template.yaml template.yaml
225+
cp pipeline/parameters.json parameters.json
226+
if [ -f cfn/$ServiceType/post_build.py ]; then
211227
cp cfn/$ServiceType/post_build.py post_build.py
212228
python post_build.py --region $region --model_id $model_id --model_tag $ModelTag --framework_type $FrameworkType --service_type $service --backend_type $backend_name --model_s3_bucket $model_s3_bucket --instance_type $instance_type --extra_params "$extra_params"
213229
fi
214-
- cat parameters.json
215-
- echo Build completed on `date`
230+
cat parameters.json
231+
echo Build completed on `date`
216232
217233
artifacts:
218234
files:

src/dmaa/cfn/ecs/cluster.yaml

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@ Description: DMAA ECS Cluster - Ensure all associated models are deleted before
33
Parameters:
44
VPCID:
55
Type: AWS::EC2::VPC::Id
6-
Default: "vpc-043d6906ccc5614f9"
76
Description: The VPC ID to be used for the ECS cluster
87
Subnets:
98
Type: List<AWS::EC2::Subnet::Id>
10-
Default: "subnet-0d74bc4b17d7ab99b,subnet-0bd0680e698e529f2"
119
Description: The public subnets to be used for the ECS cluster
1210
Resources:
1311
ECSCluster:
@@ -65,10 +63,26 @@ Resources:
6563
- sts:AssumeRole
6664
Path: /
6765
ManagedPolicyArns:
68-
- arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role
69-
- arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore
70-
- arn:aws:iam::aws:policy/AmazonS3FullAccess
71-
- arn:aws:iam::aws:policy/AmazonECS_FullAccess
66+
- Fn::Join:
67+
- ""
68+
- - "arn:"
69+
- Ref: AWS::Partition
70+
- :iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role
71+
- Fn::Join:
72+
- ""
73+
- - "arn:"
74+
- Ref: AWS::Partition
75+
- :iam::aws:policy/AmazonSSMManagedInstanceCore
76+
- Fn::Join:
77+
- ""
78+
- - "arn:"
79+
- Ref: AWS::Partition
80+
- :iam::aws:policy/AmazonS3FullAccess
81+
- Fn::Join:
82+
- ""
83+
- - "arn:"
84+
- Ref: AWS::Partition
85+
- :iam::aws:policy/AmazonECS_FullAccess
7286
ECSTaskExecutionRole:
7387
Type: AWS::IAM::Role
7488
Properties:
@@ -82,7 +96,11 @@ Resources:
8296
- sts:AssumeRole
8397
Path: /
8498
ManagedPolicyArns:
85-
- arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy
99+
- Fn::Join:
100+
- ""
101+
- - "arn:"
102+
- Ref: AWS::Partition
103+
- :iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy
86104

87105
LambdaExecutionRole:
88106
Type: AWS::IAM::Role

src/dmaa/cfn/ecs/post_build.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
import boto3
22
import time
33
import json
4+
import os
45
import argparse
56

67
# Post build script for ECS, it will deploy the VPC and ECS cluster.
78

89
CFN_ROOT_PATH = 'cfn'
910
WAIT_SECONDS = 10
10-
CFN_ROOT_PATH = '../../cfn'
11+
# CFN_ROOT_PATH = '../../cfn'
12+
JSON_DOUBLE_QUOTE_REPLACE = '<!>'
13+
14+
def load_extra_params(string):
15+
string = string.replace(JSON_DOUBLE_QUOTE_REPLACE,'"')
16+
try:
17+
return json.loads(string)
18+
except json.JSONDecodeError:
19+
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {string}")
20+
21+
def dump_extra_params(d:dict):
22+
return json.dumps(d).replace('"', JSON_DOUBLE_QUOTE_REPLACE)
1123

1224
def wait_for_stack_completion(client, stack_id, stack_name):
1325
while True:
@@ -29,14 +41,28 @@ def create_or_update_stack(client, stack_name, template_path, parameters=[]):
2941
try:
3042
response = client.describe_stacks(StackName=stack_name)
3143
stack_status = response['Stacks'][0]['StackStatus']
32-
# if stack_status == 'ROLLBACK_COMPLETE':
33-
# print(f"Stack {stack_name} is in ROLLBACK_COMPLETE state. Deleting the stack to allow for recreation.")
34-
# client.delete_stack(StackName=stack_name)
35-
# wait_for_stack_completion(client, stack_name, stack_name)
44+
if stack_status in ['ROLLBACK_COMPLETE', 'ROLLBACK_FAILED', 'DELETE_FAILED']:
45+
print(f"Stack {stack_name} is in {stack_status} state. Deleting the stack to allow for recreation.")
46+
client.delete_stack(StackName=stack_name)
47+
while True:
48+
try:
49+
response = client.describe_stacks(StackName=stack_name)
50+
stack_status = response['Stacks'][0]['StackStatus']
51+
if stack_status == 'DELETE_IN_PROGRESS':
52+
print(f"Stack {stack_name} is being deleted...")
53+
time.sleep(WAIT_SECONDS)
54+
else:
55+
raise Exception(f"Unexpected status {stack_status} while waiting for stack deletion.")
56+
except client.exceptions.ClientError as e:
57+
if 'does not exist' in str(e):
58+
print(f"Stack {stack_name} successfully deleted.")
59+
break
60+
else:
61+
raise
3662
while stack_status not in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']:
3763
if stack_status in ['CREATE_IN_PROGRESS', 'UPDATE_IN_PROGRESS']:
3864
print(f"Stack {stack_name} is currently {stack_status}. Waiting for it to complete...")
39-
time.sleep(SLEEP_WAIT_SECONDS)
65+
time.sleep(WAIT_SECONDS)
4066
response = client.describe_stacks(StackName=stack_name)
4167
stack_status = response['Stacks'][0]['StackStatus']
4268
else:
@@ -108,27 +134,27 @@ def deploy_ecs_cluster_template(region, vpc_id, subnets):
108134

109135

110136
def post_build():
111-
parser = argparse.ArgumentParser(description="Post build script")
112-
parser.add_argument('--region', type=str, required=False, help='AWS region')
113-
parser.add_argument('--model_id', type=str, required=False, help='Model ID')
114-
parser.add_argument('--model_tag', type=str, required=False, help='Model tag')
115-
parser.add_argument('--framework_type', type=str, required=False, help='Framework type')
116-
parser.add_argument('--service_type', type=str, required=False, help='Service type')
117-
parser.add_argument('--backend_type', type=str, required=False, help='Backend type')
118-
parser.add_argument('--model_s3_bucket', type=str, required=False, help='Model S3 bucket')
119-
parser.add_argument('--instance_type', type=str, required=False, help='Instance type')
120-
parser.add_argument('--extra_params', type=str, required=False, help='Extra parameters')
137+
parser = argparse.ArgumentParser()
138+
parser.add_argument('--region', type=str, required=False)
139+
parser.add_argument('--model_id', type=str, required=False)
140+
parser.add_argument('--model_tag', type=str, required=False)
141+
parser.add_argument('--framework_type', type=str, required=False)
142+
parser.add_argument('--service_type', type=str, required=False)
143+
parser.add_argument('--backend_type', type=str, required=False)
144+
parser.add_argument('--model_s3_bucket', type=str, required=False)
145+
parser.add_argument('--instance_type', type=str, required=False)
146+
parser.add_argument('--extra_params', type=load_extra_params, required=False, default=os.environ.get("extra_params","{}"))
121147

122148
args = parser.parse_args()
123149

124-
extra_params = json.loads(args.extra_params) if args.extra_params else {}
150+
service_params = args.extra_params.get('service_params',{})
125151

126-
if 'VpcID' not in extra_params:
152+
if 'vpc_id' not in service_params:
127153
vpc_id, subnets = deploy_vpc_template(args.region)
128154
else:
129-
vpc_id = extra_params.get('VpcID')
130-
subnets = extra_params.get('Subnets')
131-
update_parameters_file('parameters.json', {'VpcID': vpc_id, 'Subnets': subnets})
155+
vpc_id = service_params.get('vpc_id')
156+
subnets = service_params.get('subnet_ids')
157+
update_parameters_file('parameters.json', {'VPCID': vpc_id, 'Subnets': subnets})
132158

133159
deploy_ecs_cluster_template(args.region, vpc_id, subnets)
134160

src/dmaa/commands/deploy.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def deploy(
228228
# console.print("[bold blue]Checking AWS environment...[/bold blue]")
229229

230230
region = get_current_region()
231-
231+
vpc_id = None
232232
# ask model id
233233
model_id = ask_model_id(region,model_id=model_id)
234234

@@ -269,7 +269,75 @@ def deploy(
269269
if not check_service_support_on_cn_region(service_type,region):
270270
raise ServiceNotSupported(region, service_type=service_type)
271271

272-
#
272+
if Service.get_service_from_service_type(service_type).need_vpc:
273+
client = boto3.client('ec2', region_name=region)
274+
vpcs = []
275+
dmaa_default_vpc = None
276+
paginator = client.get_paginator('describe_vpcs')
277+
for page in paginator.paginate():
278+
for vpc in page['Vpcs']:
279+
if any(tag['Key'] == 'Name' and tag['Value'] == 'DMAA-vpc' for tag in vpc.get('Tags', [])):
280+
dmaa_default_vpc = vpc
281+
continue
282+
vpcs.append(vpc)
283+
else:
284+
for vpc in vpcs:
285+
vpc_name = next((tag['Value'] for tag in vpc.get('Tags', []) if tag.get('Key') == 'Name'), None)
286+
vpc['Name'] = vpc_name if vpc_name else '-'
287+
dmaa_vpc = select_with_help(
288+
"Select the VPC (Virtual Private Cloud) you want to deploy the ESC service:",
289+
choices=[
290+
Choice(
291+
title=f"{dmaa_default_vpc['VpcId']} ({dmaa_default_vpc['CidrBlock']}) (DMAA-vpc)" if dmaa_default_vpc else 'Create a new VPC',
292+
description='Use the existing DMAA-VPC for the new model deployment (recommended)' if dmaa_default_vpc else 'Create a new VPC with two public subnets and a S3 Endpoint for the model deployment. Select this option if you do not know what is VPC',
293+
)
294+
] + [
295+
Choice(
296+
title=f"{vpc['VpcId']} ({vpc['CidrBlock']}) ({vpc['Name']})",
297+
description="Custom VPC requirement: A NAT Gateway or S3 Endpoint, with at least two public and two private subnet.",
298+
)
299+
for vpc in vpcs
300+
],
301+
style=custom_style
302+
).ask()
303+
304+
vpc_id = None
305+
selected_subnet_ids = []
306+
if 'Create a new VPC' == dmaa_vpc:
307+
pass
308+
elif 'DMAA-vpc' in dmaa_vpc:
309+
vpc_id = dmaa_vpc.split()[0]
310+
paginator = client.get_paginator('describe_subnets')
311+
for page in paginator.paginate(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]):
312+
for subnet in page['Subnets']:
313+
selected_subnet_ids.append(subnet['SubnetId'])
314+
else:
315+
vpc_id = dmaa_vpc.split()[0]
316+
subnets = []
317+
paginator = client.get_paginator('describe_subnets')
318+
for page in paginator.paginate(Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]):
319+
for subnet in page['Subnets']:
320+
subnets.append(subnet)
321+
if not subnets:
322+
console.print("[bold red]No subnets found in the selected VPC.[/bold red]")
323+
raise typer.Exit(0)
324+
else:
325+
for subnet in subnets:
326+
subnet_name = next((tag['Value'] for tag in subnet.get('Tags', []) if tag.get('Key') == 'Name'), None)
327+
subnet['Name'] = subnet_name if subnet_name else '-'
328+
selected_subnet = questionary.checkbox(
329+
"Select multiple subnets for the model deployment:",
330+
choices=[
331+
f"{subnet['SubnetId']} ({subnet['CidrBlock']}) ({subnet['Name']})"
332+
for subnet in subnets
333+
],
334+
style=custom_style
335+
).ask()
336+
if selected_subnet is None:
337+
raise typer.Exit(0)
338+
else:
339+
for subnet in selected_subnet:
340+
selected_subnet_ids.append(subnet.split()[0])
273341

274342
# support instance
275343
supported_instances = model.supported_instances
@@ -395,12 +463,18 @@ def deploy(
395463
except json.JSONDecodeError as e:
396464
console.print("[red]Invalid JSON format. Please try again.[/red]")
397465

466+
# append extra params for VPC and subnets
467+
if vpc_id:
468+
if 'service_params' not in extra_params:
469+
extra_params['service_params'] = {}
470+
extra_params['service_params']['vpc_id'] = vpc_id
471+
extra_params['service_params']['subnet_ids'] = ",".join(selected_subnet_ids)
398472
# model tag
399473
if model_tag==MODEL_DEFAULT_TAG and not skip_confirm and not service_type == ServiceType.LOCAL:
400474
while True:
401475
model_tag = questionary.text(
402476
"(Optional) Add a model deployment tag (custom label), you can skip by pressing Enter:",
403-
default=""
477+
default="dev"
404478
).ask()
405479
if model_tag == MODEL_DEFAULT_TAG:
406480
console.print(f"[bold blue]invalid model tag: {model_tag}. Please try again.[/bold blue]")

src/dmaa/models/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class Service(ModelBase):
101101
name: str
102102
description: str
103103
support_cn_region: bool
104-
# support_custom_vpc: bool = False
104+
need_vpc: bool = False
105105

106106
# class vars
107107
service_name_maps: ClassVar[dict] = {}

src/dmaa/models/services.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@
5757
name = "Amazon EC2",
5858
service_type=ServiceType.EC2,
5959
description="Amazon Elastic Compute Cloud (Amazon EC2) provides scalable computing capacity in the Amazon Web Services (AWS) cloud.",
60-
support_cn_region = False
60+
support_cn_region = False,
61+
need_vpc = True
6162
)
6263

6364
ecs_service = Service(
@@ -77,7 +78,8 @@
7778
name = "Amazon ECS",
7879
service_type=ServiceType.ECS,
7980
description="Amazon ECS is a fully managed service that provides scalable and reliable container orchestration for your applications.",
80-
support_cn_region = False
81+
support_cn_region = True,
82+
need_vpc = True
8183
)
8284

8385

src/dmaa/utils/aws_service_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def get_account_id():
7575

7676

7777
def create_s3_bucket(bucket_name, region):
78-
s3 = boto3.client("s3")
78+
s3 = boto3.client("s3", region_name=region)
7979
try:
8080
s3.head_bucket(Bucket=bucket_name)
8181
except:

0 commit comments

Comments
 (0)