@@ -165,7 +165,7 @@ def run(self):
165
165
if matches :
166
166
encoded_message = matches .group (1 )
167
167
logger .error (f"Encoded message:\n { encoded_message } " )
168
- sts_client = boto3 . client ( 'sts' , region_name = self .region )
168
+ sts_client = plugin . sts_client [ self .region ]
169
169
decoded_message = json .loads (sts_client .decode_authorization_message (EncodedMessage = encoded_message )['DecodedMessage' ])
170
170
logger .error (f"decoded_message:\n { json .dumps (decoded_message , indent = 4 )} " )
171
171
except Exception as e :
@@ -300,16 +300,19 @@ def __init__(self, slurm_config_file=f"/opt/slurm/config/slurm_config.json", reg
300
300
301
301
self .instance_types = None
302
302
303
+ # Create all of the boto3 clients in one place so that we can make sure that all client API calls get throttling retries.
303
304
# Create the CloudWatch client first so that we can publish metrics for unhandled exceptions
304
305
self .cw = boto3 .client ('cloudwatch' )
305
- self .ssm = boto3 .client ('ssm' )
306
306
307
307
try :
308
+ self .ssm_client = boto3 .client ('ssm' )
308
309
self .ec2 = {}
309
310
self .ec2_describe_instances_paginator = {}
311
+ self .sts_client = {}
310
312
for region in self .compute_regions :
311
313
self .ec2 [region ] = boto3 .client ('ec2' , region_name = region )
312
314
self .ec2_describe_instances_paginator [region ] = self .ec2 [region ].get_paginator ('describe_instances' )
315
+ self .sts_client [region ] = boto3 .client ('sts' , region_name = region )
313
316
except :
314
317
logger .exception ('Unhandled exception in SlurmPlugin constructor' )
315
318
self .publish_cw_metrics (self .CW_UNHANDLED_PLUGIN_CONSTRUCTOR_EXCEPTION , 1 , [])
@@ -563,7 +566,7 @@ def add_hostname_to_hostinfo(self, hostname):
563
566
564
567
ssm_parameter_name = f"/{ self .config ['STACK_NAME' ]} /SlurmNodeAmis/{ distribution } /{ distribution_major_version } /{ architecture } /{ hostinfo ['region' ]} "
565
568
try :
566
- hostinfo ['ami' ] = self .ssm . get_parameter ( Name = ssm_parameter_name )[ "Parameter" ][ "Value" ]
569
+ hostinfo ['ami' ] = self .get_ssm_parameter ( ssm_parameter_name )
567
570
except Exception as e :
568
571
logging .exception (f"Error getting ami from SSM parameter { ssm_parameter_name } " )
569
572
# Don't have a way of handling this.
@@ -787,8 +790,8 @@ def resume_region(self, region):
787
790
az = self .az_ids [az_id ]
788
791
region = self .az_info [az ]['region' ]
789
792
subnet = self .az_info [az ]['subnet' ]
790
- security_group_id = self .ssm . get_parameter ( Name = f"/{ self .config ['STACK_NAME' ]} /SlurmNodeSecurityGroups/{ region } " )[ 'Parameter' ][ 'Value' ]
791
- key_name = self .ssm . get_parameter ( Name = f"/{ self .config ['STACK_NAME' ]} /SlurmNodeEc2KeyPairs/{ region } " )[ 'Parameter' ][ 'Value' ]
793
+ security_group_id = self .get_ssm_parameter ( f"/{ self .config ['STACK_NAME' ]} /SlurmNodeSecurityGroups/{ region } " )
794
+ key_name = self .get_ssm_parameter ( f"/{ self .config ['STACK_NAME' ]} /SlurmNodeEc2KeyPairs/{ region } " )
792
795
ami = hostinfo ['ami' ]
793
796
userData = userDataTemplate .render ({
794
797
'DOMAIN' : self .config ['DOMAIN' ],
@@ -1304,7 +1307,7 @@ def stop_instanceIds(self, region, hostnames_to_stop, instanceIds_to_stop,
1304
1307
if matches :
1305
1308
encoded_message = matches .group (1 )
1306
1309
logger .error (f"Encoded message:\n { encoded_message } " )
1307
- sts_client = boto3 . client ( 'sts' , region_name = region )
1310
+ sts_client = self . sts_client [ region ]
1308
1311
decoded_message = json .loads (sts_client .decode_authorization_message (EncodedMessage = encoded_message )['DecodedMessage' ])
1309
1312
logger .error (f"decoded_message:\n { json .dumps (decoded_message , indent = 4 )} " )
1310
1313
else :
@@ -1351,7 +1354,7 @@ def terminate_instanceIds(self, region, hostnames_to_terminate, instanceIds_to_t
1351
1354
if matches :
1352
1355
encoded_message = matches .group (1 )
1353
1356
logger .error (f"Encoded message:\n { encoded_message } " )
1354
- sts_client = boto3 . client ( 'sts' , region_name = region )
1357
+ sts_client = self . sts_client [ region ]
1355
1358
decoded_message = json .loads (sts_client .decode_authorization_message (EncodedMessage = encoded_message )['DecodedMessage' ])
1356
1359
logger .error (f"decoded_message:\n { json .dumps (decoded_message , indent = 4 )} " )
1357
1360
else :
@@ -1904,10 +1907,10 @@ def get_az_info_from_instance_config(self, instance_config: dict) -> dict:
1904
1907
az_info = {}
1905
1908
for region , region_dict in instance_config ['Regions' ].items ():
1906
1909
logger .debug (f"region: { region } " )
1907
- ec2_client = boto3 . client ( ' ec2' , region_name = region )
1910
+ ec2_client = self . ec2 [ region ]
1908
1911
for az_dict in region_dict ['AZs' ]:
1909
1912
subnet = az_dict ['Subnet' ]
1910
- subnet_info = ec2_client .describe_subnets (SubnetIds = [subnet ])['Subnets' ][0 ]
1913
+ subnet_info = self .describe_subnets (ec2_client , { ' SubnetIds' : [subnet ]} )['Subnets' ][0 ]
1911
1914
az = subnet_info ['AvailabilityZone' ]
1912
1915
az_id = subnet_info ['AvailabilityZoneId' ]
1913
1916
az_info [az ] = {
@@ -2116,6 +2119,11 @@ def publish_cw(self):
2116
2119
return 1
2117
2120
return 0
2118
2121
2122
@retry_ec2_throttling()
def describe_subnets(self, ec2_client, kwargs):
    """Call ``describe_subnets`` on the supplied EC2 client.

    The call is wrapped by the ``retry_ec2_throttling`` decorator, which is
    defined elsewhere in this file.

    Args:
        ec2_client: Client object whose ``describe_subnets`` method is invoked.
        kwargs: Dict of keyword arguments that is expanded into the call.

    Returns:
        Whatever ``ec2_client.describe_subnets`` returns, unmodified.
    """
    return ec2_client.describe_subnets(**kwargs)
2126
+
2119
2127
@retry_ec2_throttling ()
2120
2128
def paginate (self , paginator , kwargs ):
2121
2129
result = paginator .paginate (** kwargs )
@@ -2135,3 +2143,8 @@ def stop_instances(self, region, kwargs):
2135
2143
def terminate_instances(self, region, kwargs):
    """Call ``terminate_instances`` on the EC2 client stored for ``region``.

    Args:
        region: Key used to look up the client in ``self.ec2``.
        kwargs: Dict of keyword arguments that is expanded into the call.

    Returns:
        Whatever the client's ``terminate_instances`` call returns, unmodified.

    Raises:
        KeyError: If ``self.ec2`` has no entry for ``region``.
    """
    regional_client = self.ec2[region]
    return regional_client.terminate_instances(**kwargs)
2146
+
2147
@retry_ec2_throttling()
def get_ssm_parameter(self, ssm_parameter_name):
    """Return the value of the named SSM parameter.

    Calls ``get_parameter`` on ``self.ssm_client`` and extracts
    ``["Parameter"]["Value"]`` from the response.

    Args:
        ssm_parameter_name: Passed as the ``Name`` argument of ``get_parameter``.

    Returns:
        The ``Value`` field of the returned ``Parameter`` entry.
    """
    # NOTE(review): this SSM call is wrapped by a decorator named for EC2
    # throttling; confirm that it also retries SSM throttling errors.
    response = self.ssm_client.get_parameter(Name=ssm_parameter_name)
    parameter = response["Parameter"]
    return parameter["Value"]
0 commit comments