Skip to content

Service sdk sigv4 #1381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ data class KotlinDependency(
val KOTLINX_CBOR_SERDE = KotlinDependency(GradleConfiguration.Implementation, "kotlinx.serialization", "org.jetbrains.kotlinx", "kotlinx-serialization-cbor", KOTLINX_VERSION)
val KOTLINX_JSON_SERDE = KotlinDependency(GradleConfiguration.Implementation, "kotlinx.serialization.json", "org.jetbrains.kotlinx", "kotlinx-serialization-json", KOTLINX_VERSION)
val KTOR_SERVER_AUTH = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.auth", "io.ktor", "ktor-server-auth", KTOR_VERSION)
val KTOR_SERVER_DOUBLE_RECEIVE = KotlinDependency(GradleConfiguration.Implementation, "io.ktor.server.plugins.doublereceive", "io.ktor", "ktor-server-double-receive-jvm", KTOR_VERSION)
}

override fun getDependencies(): List<SymbolDependency> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ object RuntimeTypes {

object Net : RuntimeTypePackage(KotlinDependency.CORE, "net") {
val Host = symbol("Host")
val Scheme = symbol("Scheme")

object Url : RuntimeTypePackage(KotlinDependency.CORE, "net.url") {
val QueryParameters = symbol("QueryParameters")
Expand Down Expand Up @@ -393,6 +394,7 @@ object RuntimeTypes {
val mergeAuthOptions = symbol("mergeAuthOptions")
val sigV4 = symbol("sigV4")
val sigV4A = symbol("sigV4A")
val SignHttpRequest = symbol("SignHttpRequest")
}

object AwsSigningCrt : RuntimeTypePackage(KotlinDependency.AWS_SIGNING_CRT) {
Expand Down Expand Up @@ -530,12 +532,15 @@ object RuntimeTypes {

val requestReceive = symbol("receive", "request")
val requestUri = symbol("uri", "request")
val requestHeader = symbol("header", "request")
val requestHttpMethod = symbol("httpMethod", "request")
val requestApplicationRequest = symbol("ApplicationRequest", "request")
val requestContentLength = symbol("contentLength", "request")
val requestContentType = symbol("contentType", "request")
val requestAcceptItems = symbol("acceptItems", "request")
val requestPath = symbol("path", "request")

val responseResponse = symbol("respond", "response")
val responseResponseText = symbol("respondText", "response")
val responseRespondBytes = symbol("respondBytes", "response")
}
Expand All @@ -560,6 +565,8 @@ object RuntimeTypes {
val HttpStatusCode = symbol("HttpStatusCode")
val parseAndSortHeader = symbol("parseAndSortHeader")
val HttpHeaders = symbol("HttpHeaders")
val HeadersBuilder = symbol("HeadersBuilder")
val Parameters = symbol("Parameters")
val Cbor = symbol("Cbor", "ContentType.Application")
val Json = symbol("Json", "ContentType.Application")
val Any = symbol("Any", "ContentType.Application")
Expand Down Expand Up @@ -597,6 +604,15 @@ object RuntimeTypes {
val authenticate = symbol("authenticate")
val Principal = symbol("Principal")
val bearer = symbol("bearer")
val AuthenticationConfig = symbol("AuthenticationConfig")
val AuthenticationProvider = symbol("AuthenticationProvider")
val AuthenticationFailedCause = symbol("AuthenticationFailedCause")
val AuthenticationContext = symbol("AuthenticationContext")
val AuthenticationStrategy = symbol("AuthenticationStrategy")
}

object KtorServerDoubleReceive : RuntimeTypePackage(KotlinDependency.KTOR_SERVER_DOUBLE_RECEIVE) {
val DoubleReceive = symbol("DoubleReceive")
}

object KotlinxCborSerde : RuntimeTypePackage(KotlinDependency.KOTLINX_CBOR_SERDE) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ internal interface ServiceStubGenerator {
}

internal abstract class AbstractStubGenerator(
protected val ctx: GenerationContext,
protected val delegator: KotlinDelegator,
protected val fileManifest: FileManifest,
val ctx: GenerationContext,
val delegator: KotlinDelegator,
val fileManifest: FileManifest,
) : ServiceStubGenerator {

protected val serviceShape = ctx.settings.getService(ctx.model)
protected val operations = TopDownIndex.of(ctx.model)
val serviceShape = ctx.settings.getService(ctx.model)
val operations = TopDownIndex.of(ctx.model)
.getContainedOperations(serviceShape)
.sortedBy { it.defaultName() }

protected val pkgName = ctx.settings.pkg.name
val pkgName = ctx.settings.pkg.name

final override fun render() {
renderServiceFrameworkConfig()
Expand Down Expand Up @@ -101,6 +101,7 @@ internal abstract class AbstractStubGenerator(
withBlock("private data class Data(", ")") {
write("val port: Int,")
write("val engine: #T,", ServiceTypes(pkgName).serviceEngine)
write("val region: String,")
write("val requestBodyLimit: Long,")
write("val requestReadTimeoutSeconds: Int,")
write("val responseWriteTimeoutSeconds: Int,")
Expand All @@ -111,6 +112,7 @@ internal abstract class AbstractStubGenerator(
write("")
write("val port: Int get() = backing?.port ?: notInitialised(#S)", "port")
write("val engine: #T get() = backing?.engine ?: notInitialised(#S)", ServiceTypes(pkgName).serviceEngine, "engine")
write("val region: String get() = backing?.region ?: notInitialised(#S)", "region")
write("val requestBodyLimit: Long get() = backing?.requestBodyLimit ?: notInitialised(#S)", "requestBodyLimit")
write("val requestReadTimeoutSeconds: Int get() = backing?.requestReadTimeoutSeconds ?: notInitialised(#S)", "requestReadTimeoutSeconds")
write("val responseWriteTimeoutSeconds: Int get() = backing?.responseWriteTimeoutSeconds ?: notInitialised(#S)", "responseWriteTimeoutSeconds")
Expand All @@ -121,6 +123,7 @@ internal abstract class AbstractStubGenerator(
withInlineBlock("fun init(", ")") {
write("port: Int,")
write("engine: #T,", ServiceTypes(pkgName).serviceEngine)
write("region: String,")
write("requestBodyLimit: Long,")
write("requestReadTimeoutSeconds: Int,")
write("responseWriteTimeoutSeconds: Int,")
Expand All @@ -130,7 +133,7 @@ internal abstract class AbstractStubGenerator(
}
withBlock("{", "}") {
write("check(backing == null) { #S }", "ServiceFrameworkConfig has already been initialised")
write("backing = Data(port, engine, requestBodyLimit, requestReadTimeoutSeconds, responseWriteTimeoutSeconds, closeGracePeriodMillis, closeTimeoutMillis, logLevel)")
write("backing = Data(port, engine, region, requestBodyLimit, requestReadTimeoutSeconds, responseWriteTimeoutSeconds, closeGracePeriodMillis, closeTimeoutMillis, logLevel)")
}
write("")
withBlock("private fun notInitialised(prop: String): Nothing {", "}") {
Expand Down Expand Up @@ -178,6 +181,7 @@ internal abstract class AbstractStubGenerator(
protected fun renderMainFile() {
val portName = "port"
val engineFactoryName = "engineFactory"
val regionName = "region"
val requestBodyLimitName = "requestBodyLimit"
val requestReadTimeoutSecondsName = "requestReadTimeoutSeconds"
val responseWriteTimeoutSecondsName = "responseWriteTimeoutSeconds"
Expand All @@ -191,6 +195,7 @@ internal abstract class AbstractStubGenerator(
write("")
write("val defaultPort = 8080")
write("val defaultEngine = #T.NETTY_ENGINE.value", ServiceTypes(pkgName).serviceEngine)
write("val defaultRegion = #S", "us-east-1")
write("val defaultRequestBodyLimit = 10L * 1024 * 1024")
write("val defaultRequestReadTimeoutSeconds = 30")
write("val defaultResponseWriteTimeoutSeconds = 30")
Expand All @@ -201,6 +206,7 @@ internal abstract class AbstractStubGenerator(
withBlock("#T.init(", ")", ServiceTypes(pkgName).serviceFrameworkConfig) {
write("port = argMap[#S]?.toInt() ?: defaultPort, ", portName)
write("engine = #T.fromValue(argMap[#S] ?: defaultEngine), ", ServiceTypes(pkgName).serviceEngine, engineFactoryName)
write("region = argMap[#S]?.toString() ?: defaultRegion, ", regionName)
write("requestBodyLimit = argMap[#S]?.toLong() ?: defaultRequestBodyLimit, ", requestBodyLimitName)
write("requestReadTimeoutSeconds = argMap[#S]?.toInt() ?: defaultRequestReadTimeoutSeconds, ", requestReadTimeoutSecondsName)
write("responseWriteTimeoutSeconds = argMap[#S]?.toInt() ?: defaultResponseWriteTimeoutSeconds, ", responseWriteTimeoutSecondsName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ internal class ConstraintGenerator(

private fun renderRequestConstraintsValidation() {
delegator.useFileWriter("${opName}RequestConstraints.kt", "$pkgName.constraints") { writer ->
writer.addImport("$pkgName.model", "${operation.id.name}Request")
val inputShape = ctx.model.expectShape(operation.input.get())
val inputSymbol = ctx.symbolProvider.toSymbol(inputShape)

writer.withBlock("public fun check${opName}RequestConstraint(data: ${opName}Request) {", "}") {
writer.withBlock("public fun check${opName}RequestConstraint(data: #T) {", "}", inputSymbol) {
for (memberShape in inputMembers.values) {
generateConstraintValidations("data.", memberShape, writer)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package software.amazon.smithy.kotlin.codegen.service.ktor

import software.amazon.smithy.aws.traits.auth.SigV4ATrait
import software.amazon.smithy.aws.traits.auth.SigV4Trait
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.model.getTrait
import software.amazon.smithy.kotlin.codegen.service.KtorStubGenerator
import software.amazon.smithy.kotlin.codegen.service.ServiceTypes

internal fun KtorStubGenerator.writeAuthentication() {
delegator.useFileWriter("UserPrincipal.kt", "$pkgName.auth") { writer ->
writer.withBlock("public data class UserPrincipal(", ")") {
write("val user: String")
}
}

delegator.useFileWriter("Validation.kt", "$pkgName.auth") { writer ->

writer.withBlock("internal object BearerValidation {", "}") {
withBlock("public fun bearerValidation(token: String): UserPrincipal? {", "}") {
write("// TODO: implement me")
write("if (true) return UserPrincipal(#S) else return null", "Authenticated User")
}
}
}

delegator.useFileWriter("Authentication.kt", "$pkgName.auth") { writer ->
writer.withBlock("internal fun #T.configureAuthentication() {", "}", RuntimeTypes.KtorServerCore.Application) {
write("")
withBlock(
"#T(#T) {",
"}",
RuntimeTypes.KtorServerCore.install,
RuntimeTypes.KtorServerAuth.Authentication,
) {
withBlock("#T(#S) {", "}", RuntimeTypes.KtorServerAuth.bearer, "auth-bearer") {
write("realm = #S", "Access to API")
write("authenticate { cred -> BearerValidation.bearerValidation(cred.token) }")
}
withBlock("sigV4(name = #S) {", "}", "aws-sigv4") {
write("region = #T.region", ServiceTypes(pkgName).serviceFrameworkConfig)
val serviceSigV4AuthTrait = serviceShape.getTrait<SigV4Trait>()
if (serviceSigV4AuthTrait != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be rendering the

withBlock("sigV4(name = #S) {", "}", "aws-sigv4") {

block from above if the service doesn't have the SigV4 trait?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think yes, because once generated, you should allow user to use sigv4 if they want to add it manually

write("service = #S", serviceSigV4AuthTrait.name)
}
}
withBlock("sigV4A(name = #S) {", "}", "aws-sigv4a") {
write("region = #T.region", ServiceTypes(pkgName).serviceFrameworkConfig)
val serviceSigV4AAuthTrait = serviceShape.getTrait<SigV4ATrait>()
if (serviceSigV4AAuthTrait != null) {
write("service = #S", serviceSigV4AAuthTrait.name)
}
}
write("provider(#S) { authenticate { ctx -> ctx.principal(Unit) } }", "no-auth")
}
}
}
}
Loading
Loading