Skip to content

Commit

Permalink
Fix auth scheme issue by registering auth schemes as a default-level …
Browse files Browse the repository at this point in the history
…client plugin
  • Loading branch information
rcoh committed Oct 23, 2023
1 parent b471b46 commit d9761d3
Show file tree
Hide file tree
Showing 16 changed files with 256 additions and 107 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,9 @@ message = "Clients now have a default async sleep implementation so that one doe
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Enable custom auth schemes to work by changing the code generated auth options to be set at the client level at `DEFAULTS` priority."
references = ["smithy-rs#3034", "smithy-rs#3087"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "rcoh"
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,35 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isInputEventStream
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class SigV4AuthDecorator : ClientCodegenDecorator {
override val name: String get() = "SigV4AuthDecorator"
override val order: Byte = 0

private fun sigv4(runtimeConfig: RuntimeConfig) = writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
}

private fun sigv4a(runtimeConfig: RuntimeConfig) = writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
featureGateBlock("sigv4a") {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
}

override fun authOptions(
codegenContext: ClientCodegenContext,
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> = baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(SigV4Trait.ID) {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig).resolve("auth")
rust("#T,", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
featureGateBlock("sigv4a") {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
rust(",")
}
): List<AuthSchemeOption> {
val supportsSigV4a = codegenContext.serviceShape.supportedAuthSchemes().contains("sig4a")
.thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) }
return baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(
SigV4Trait.ID,
listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a,
)
}

override fun serviceRuntimePluginCustomizations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ class HttpAuthDecorator : ClientCodegenDecorator {
options.add(
StaticAuthSchemeOption(
schemeShapeId,
writable {
rustTemplate("$name,", *codegenScope)
},
listOf(
writable {
rustTemplate(name, *codegenScope)
},
),
),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOpt
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

val noAuthSchemeShapeId: ShapeId = ShapeId.from("aws.smithy.rs#NoAuth")
Expand All @@ -30,10 +31,15 @@ class NoAuthDecorator : ClientCodegenDecorator {
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> = baseAuthSchemeOptions +
AuthSchemeOption.StaticAuthSchemeOption(noAuthSchemeShapeId) {
rustTemplate(
"#{NO_AUTH_SCHEME_ID},",
"NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
)
}
AuthSchemeOption.StaticAuthSchemeOption(
noAuthSchemeShapeId,
listOf(
writable {
rustTemplate(
"#{NO_AUTH_SCHEME_ID}",
"NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
)
},
),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ sealed interface AuthSchemeOption {
/** Auth scheme for the `StaticAuthSchemeOptionResolver` */
data class StaticAuthSchemeOption(
val schemeShapeId: ShapeId,
val constructor: Writable,
val constructor: List<Writable>,
) : AuthSchemeOption

class CustomResolver(/* unimplemented */) : AuthSchemeOption
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customizations.noAuthSchemeShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOption
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.isEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import java.util.logging.Logger

class AuthOptionsPluginGenerator(private val codegenContext: ClientCodegenContext) {
private val runtimeConfig = codegenContext.runtimeConfig

private val logger: Logger = Logger.getLogger(javaClass.name)
fun authPlugin(operationShape: OperationShape, authSchemeOptions: List<AuthSchemeOption>) = writable {
rustTemplate(
"""
#{DefaultAuthOptionsPlugin}::new(vec![#{options}])
""",
"DefaultAuthOptionsPlugin" to RuntimeType.defaultAuthPlugin(runtimeConfig),
"options" to actualAuthSchemes(operationShape, authSchemeOptions).join(", "),
)
}

private fun actualAuthSchemes(operationShape: OperationShape, authSchemeOptions: List<AuthSchemeOption>): List<Writable> {
val out: MutableList<Writable> = mutableListOf()

var noSupportedAuthSchemes = true
val authSchemes = ServiceIndex.of(codegenContext.model)
.getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape)

for (schemeShapeId in authSchemes.keys) {
val optionsForScheme = authSchemeOptions.filter {
when (it) {
is AuthSchemeOption.CustomResolver -> false
is AuthSchemeOption.StaticAuthSchemeOption -> {
it.schemeShapeId == schemeShapeId
}
}
}

if (optionsForScheme.isNotEmpty()) {
out.addAll(optionsForScheme.flatMap { (it as AuthSchemeOption.StaticAuthSchemeOption).constructor })
noSupportedAuthSchemes = false
} else {
logger.warning(
"No auth scheme implementation available for $schemeShapeId. " +
"The generated client will not attempt to use this auth scheme.",
)
}
}
if (operationShape.hasTrait<OptionalAuthTrait>() || noSupportedAuthSchemes) {
val authOption = authSchemeOptions.find {
it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId
}
?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.")
out += (authOption as AuthSchemeOption.StaticAuthSchemeOption).constructor
}
if (out.any { it.isEmpty() }) {
PANIC("Got an empty auth scheme constructor. This is a bug. $out")
}
return out
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ open class OperationGenerator(
operationShape: OperationShape,
codegenDecorator: ClientCodegenDecorator,
) {
val operationCustomizations = codegenDecorator.operationCustomizations(codegenContext, operationShape, emptyList())
val operationCustomizations =
codegenDecorator.operationCustomizations(codegenContext, operationShape, emptyList())
renderOperationStruct(
operationWriter,
operationShape,
Expand Down Expand Up @@ -94,14 +95,22 @@ open class OperationGenerator(
"Operation" to symbolProvider.toSymbol(operationShape),
"OperationError" to errorType,
"OperationOutput" to outputType,
"HttpResponse" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::HttpResponse"),
"HttpResponse" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::orchestrator::HttpResponse"),
"SdkError" to RuntimeType.sdkError(runtimeConfig),
)
val additionalPlugins = writable {
writeCustomizations(
operationCustomizations,
OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape),
)
rustTemplate(
".with_client_plugin(#{auth_plugin})",
"auth_plugin" to AuthOptionsPluginGenerator(codegenContext).authPlugin(
operationShape,
authSchemeOptions,
),
)
}
rustTemplate(
"""
Expand Down Expand Up @@ -157,11 +166,13 @@ open class OperationGenerator(
*codegenScope,
"Error" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Error"),
"InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig),
"OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::error::OrchestratorError"),
"OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::orchestrator::error::OrchestratorError"),
"RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig),
"RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig),
"StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"),
"invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::invoke_with_stop_point"),
"invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig)
.resolve("client::orchestrator::invoke_with_stop_point"),
"additional_runtime_plugins" to writable {
if (additionalPlugins.isNotEmpty()) {
rustTemplate(
Expand All @@ -182,7 +193,6 @@ open class OperationGenerator(
operationWriter,
operationShape,
operationName,
authSchemeOptions,
operationCustomizations,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,22 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customizations.noAuthSchemeShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOption
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import java.util.logging.Logger

/**
* Generates operation-level runtime plugins
*/
class OperationRuntimePluginGenerator(
private val codegenContext: ClientCodegenContext,
) {
private val logger: Logger = Logger.getLogger(javaClass.name)
private val codegenScope = codegenContext.runtimeConfig.let { rc ->
val runtimeApi = RuntimeType.smithyRuntimeApi(rc)
val smithyTypes = RuntimeType.smithyTypes(rc)
Expand Down Expand Up @@ -57,7 +48,6 @@ class OperationRuntimePluginGenerator(
writer: RustWriter,
operationShape: OperationShape,
operationStructName: String,
authSchemeOptions: List<AuthSchemeOption>,
customizations: List<OperationCustomization>,
) {
writer.rustTemplate(
Expand All @@ -80,7 +70,6 @@ class OperationRuntimePluginGenerator(
fn runtime_components(&self, _: &#{RuntimeComponentsBuilder}) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
#{Cow}::Owned(
#{RuntimeComponentsBuilder}::new(${operationShape.id.name.dq()})
#{auth_options}
#{interceptors}
#{retry_classifiers}
)
Expand All @@ -91,7 +80,6 @@ class OperationRuntimePluginGenerator(
""",
*codegenScope,
*preludeScope,
"auth_options" to generateAuthOptions(operationShape, authSchemeOptions),
"additional_config" to writable {
writeCustomizations(
customizations,
Expand Down Expand Up @@ -122,55 +110,4 @@ class OperationRuntimePluginGenerator(
},
)
}

private fun generateAuthOptions(
operationShape: OperationShape,
authSchemeOptions: List<AuthSchemeOption>,
): Writable = writable {
if (authSchemeOptions.any { it is AuthSchemeOption.CustomResolver }) {
throw IllegalStateException("AuthSchemeOption.CustomResolver is unimplemented")
} else {
withBlockTemplate(
"""
.with_auth_scheme_option_resolver(#{Some}(
#{SharedAuthSchemeOptionResolver}::new(
#{StaticAuthSchemeOptionResolver}::new(vec![
""",
"]))))",
*codegenScope,
) {
var noSupportedAuthSchemes = true
val authSchemes = ServiceIndex.of(codegenContext.model)
.getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape)

for (schemeShapeId in authSchemes.keys) {
val optionsForScheme = authSchemeOptions.filter {
when (it) {
is AuthSchemeOption.CustomResolver -> false
is AuthSchemeOption.StaticAuthSchemeOption -> {
it.schemeShapeId == schemeShapeId
}
}
}

if (optionsForScheme.isNotEmpty()) {
optionsForScheme.forEach { (it as AuthSchemeOption.StaticAuthSchemeOption).constructor(this) }
noSupportedAuthSchemes = false
} else {
logger.warning(
"No auth scheme implementation available for $schemeShapeId. " +
"The generated client will not attempt to use this auth scheme.",
)
}
}
if (operationShape.hasTrait<OptionalAuthTrait>() || noSupportedAuthSchemes) {
val authOption = authSchemeOptions.find {
it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId
}
?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.")
(authOption as AuthSchemeOption.StaticAuthSchemeOption).constructor(this)
}
}
}
}
}
Loading

0 comments on commit d9761d3

Please sign in to comment.