Skip to content

Commit d052e61

Browse files
committed
Retry throttling on get ssm parameter
Resolves unhandled throttling exception.
1 parent a2df3a4 commit d052e61

File tree

1 file changed

+22
-9
lines changed
  • source/resources/playbooks/roles/SlurmCtl/files/opt/slurm/cluster/bin

1 file changed

+22
-9
lines changed

source/resources/playbooks/roles/SlurmCtl/files/opt/slurm/cluster/bin/SlurmPlugin.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def run(self):
165165
if matches:
166166
encoded_message = matches.group(1)
167167
logger.error(f"Encoded message:\n{encoded_message}")
168-
sts_client = boto3.client('sts', region_name=self.region)
168+
sts_client = plugin.sts_client[self.region]
169169
decoded_message = json.loads(sts_client.decode_authorization_message(EncodedMessage=encoded_message)['DecodedMessage'])
170170
logger.error(f"decoded_message:\n{json.dumps(decoded_message, indent=4)}")
171171
except Exception as e:
@@ -300,16 +300,19 @@ def __init__(self, slurm_config_file=f"/opt/slurm/config/slurm_config.json", reg
300300

301301
self.instance_types = None
302302

303+
# Create all of the boto3 clients in one place so can make sure that all client api calls get throttling retries.
303304
# Create first so that can publish metrics for unhandled exceptions
304305
self.cw = boto3.client('cloudwatch')
305-
self.ssm = boto3.client('ssm')
306306

307307
try:
308+
self.ssm_client = boto3.client('ssm')
308309
self.ec2 = {}
309310
self.ec2_describe_instances_paginator = {}
311+
self.sts_client = {}
310312
for region in self.compute_regions:
311313
self.ec2[region] = boto3.client('ec2', region_name=region)
312314
self.ec2_describe_instances_paginator[region] = self.ec2[region].get_paginator('describe_instances')
315+
self.sts_client[region] = boto3.client('sts', region_name=region)
313316
except:
314317
logger.exception('Unhandled exception in SlurmPlugin constructor')
315318
self.publish_cw_metrics(self.CW_UNHANDLED_PLUGIN_CONSTRUCTOR_EXCEPTION, 1, [])
@@ -563,7 +566,7 @@ def add_hostname_to_hostinfo(self, hostname):
563566

564567
ssm_parameter_name = f"/{self.config['STACK_NAME']}/SlurmNodeAmis/{distribution}/{distribution_major_version}/{architecture}/{hostinfo['region']}"
565568
try:
566-
hostinfo['ami'] = self.ssm.get_parameter(Name=ssm_parameter_name)["Parameter"]["Value"]
569+
hostinfo['ami'] = self.get_ssm_parameter(ssm_parameter_name)
567570
except Exception as e:
568571
logging.exception(f"Error getting ami from SSM parameter {ssm_parameter_name}")
569572
# Don't have a way of handling this.
@@ -787,8 +790,8 @@ def resume_region(self, region):
787790
az = self.az_ids[az_id]
788791
region = self.az_info[az]['region']
789792
subnet = self.az_info[az]['subnet']
790-
security_group_id = self.ssm.get_parameter(Name=f"/{self.config['STACK_NAME']}/SlurmNodeSecurityGroups/{region}")['Parameter']['Value']
791-
key_name = self.ssm.get_parameter(Name=f"/{self.config['STACK_NAME']}/SlurmNodeEc2KeyPairs/{region}")['Parameter']['Value']
793+
security_group_id = self.get_ssm_parameter(f"/{self.config['STACK_NAME']}/SlurmNodeSecurityGroups/{region}")
794+
key_name = self.get_ssm_parameter(f"/{self.config['STACK_NAME']}/SlurmNodeEc2KeyPairs/{region}")
792795
ami = hostinfo['ami']
793796
userData = userDataTemplate.render({
794797
'DOMAIN': self.config['DOMAIN'],
@@ -1304,7 +1307,7 @@ def stop_instanceIds(self, region, hostnames_to_stop, instanceIds_to_stop,
13041307
if matches:
13051308
encoded_message = matches.group(1)
13061309
logger.error(f"Encoded message:\n{encoded_message}")
1307-
sts_client = boto3.client('sts', region_name=region)
1310+
sts_client = self.sts_client[region]
13081311
decoded_message = json.loads(sts_client.decode_authorization_message(EncodedMessage=encoded_message)['DecodedMessage'])
13091312
logger.error(f"decoded_message:\n{json.dumps(decoded_message, indent=4)}")
13101313
else:
@@ -1351,7 +1354,7 @@ def terminate_instanceIds(self, region, hostnames_to_terminate, instanceIds_to_t
13511354
if matches:
13521355
encoded_message = matches.group(1)
13531356
logger.error(f"Encoded message:\n{encoded_message}")
1354-
sts_client = boto3.client('sts', region_name=region)
1357+
sts_client = self.sts_client[region]
13551358
decoded_message = json.loads(sts_client.decode_authorization_message(EncodedMessage=encoded_message)['DecodedMessage'])
13561359
logger.error(f"decoded_message:\n{json.dumps(decoded_message, indent=4)}")
13571360
else:
@@ -1904,10 +1907,10 @@ def get_az_info_from_instance_config(self, instance_config: dict) -> dict:
19041907
az_info = {}
19051908
for region, region_dict in instance_config['Regions'].items():
19061909
logger.debug(f"region: {region}")
1907-
ec2_client = boto3.client('ec2', region_name=region)
1910+
ec2_client = self.ec2[region]
19081911
for az_dict in region_dict['AZs']:
19091912
subnet = az_dict['Subnet']
1910-
subnet_info = ec2_client.describe_subnets(SubnetIds=[subnet])['Subnets'][0]
1913+
subnet_info = self.describe_subnets(ec2_client, {'SubnetIds': [subnet]})['Subnets'][0]
19111914
az = subnet_info['AvailabilityZone']
19121915
az_id = subnet_info['AvailabilityZoneId']
19131916
az_info[az] = {
@@ -2116,6 +2119,11 @@ def publish_cw(self):
21162119
return 1
21172120
return 0
21182121

2122+
@retry_ec2_throttling()
2123+
def describe_subnets(self, ec2_client, kwargs):
2124+
result = ec2_client.describe_subnets(**kwargs)
2125+
return result
2126+
21192127
@retry_ec2_throttling()
21202128
def paginate(self, paginator, kwargs):
21212129
result = paginator.paginate(**kwargs)
@@ -2135,3 +2143,8 @@ def stop_instances(self, region, kwargs):
21352143
def terminate_instances(self, region, kwargs):
21362144
result = self.ec2[region].terminate_instances(**kwargs)
21372145
return result
2146+
2147+
@retry_ec2_throttling()
2148+
def get_ssm_parameter(self, ssm_parameter_name):
2149+
value = self.ssm_client.get_parameter(Name=ssm_parameter_name)["Parameter"]["Value"]
2150+
return value

0 commit comments

Comments
 (0)