Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where Custom Auth Schemes were not respected #3087

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,9 @@ message = "Clients now have a default async sleep implementation so that one doe
references = ["smithy-rs#3071"]
meta = { "breaking" = false, "tada" = true, "bug" = false, "target" = "client" }
author = "jdisanti"

[[smithy-rs]]
message = "Enable custom auth schemes to work by changing the code generated auth options to be set at the client level at `DEFAULTS` priority."
references = ["smithy-rs#3034", "smithy-rs#3087"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "client" }
author = "rcoh"
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,37 @@ import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.core.util.hasEventStreamOperations
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isInputEventStream
import software.amazon.smithy.rust.codegen.core.util.thenSingletonListOf

class SigV4AuthDecorator : ClientCodegenDecorator {
override val name: String get() = "SigV4AuthDecorator"
override val order: Byte = 0

private val sigv4a = "sigv4a"

private fun sigv4(runtimeConfig: RuntimeConfig) = writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
rust("#T", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
}

private fun sigv4a(runtimeConfig: RuntimeConfig) = writable {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(runtimeConfig).resolve("auth")
featureGateBlock(sigv4a) {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
}

override fun authOptions(
codegenContext: ClientCodegenContext,
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> = baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(SigV4Trait.ID) {
val awsRuntimeAuthModule = AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig).resolve("auth")
rust("#T,", awsRuntimeAuthModule.resolve("sigv4::SCHEME_ID"))
if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
featureGateBlock("sigv4a") {
rust("#T", awsRuntimeAuthModule.resolve("sigv4a::SCHEME_ID"))
}
rust(",")
}
): List<AuthSchemeOption> {
val supportsSigV4a = codegenContext.serviceShape.supportedAuthSchemes().contains(sigv4a)
.thenSingletonListOf { sigv4a(codegenContext.runtimeConfig) }
return baseAuthSchemeOptions + AuthSchemeOption.StaticAuthSchemeOption(
SigV4Trait.ID,
listOf(sigv4(codegenContext.runtimeConfig)) + supportsSigV4a,
)
}

override fun serviceRuntimePluginCustomizations(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ class HttpAuthDecorator : ClientCodegenDecorator {
options.add(
StaticAuthSchemeOption(
schemeShapeId,
writable {
rustTemplate("$name,", *codegenScope)
},
listOf(
writable {
rustTemplate(name, *codegenScope)
},
),
),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOpt
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

val noAuthSchemeShapeId: ShapeId = ShapeId.from("aws.smithy.rs#NoAuth")
Expand All @@ -30,10 +31,15 @@ class NoAuthDecorator : ClientCodegenDecorator {
operationShape: OperationShape,
baseAuthSchemeOptions: List<AuthSchemeOption>,
): List<AuthSchemeOption> = baseAuthSchemeOptions +
AuthSchemeOption.StaticAuthSchemeOption(noAuthSchemeShapeId) {
rustTemplate(
"#{NO_AUTH_SCHEME_ID},",
"NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
)
}
AuthSchemeOption.StaticAuthSchemeOption(
noAuthSchemeShapeId,
listOf(
writable {
rustTemplate(
"#{NO_AUTH_SCHEME_ID}",
"NO_AUTH_SCHEME_ID" to noAuthModule(codegenContext).resolve("NO_AUTH_SCHEME_ID"),
)
},
),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ sealed interface AuthSchemeOption {
/** Auth scheme for the `StaticAuthSchemeOptionResolver` */
data class StaticAuthSchemeOption(
val schemeShapeId: ShapeId,
val constructor: Writable,
val constructor: List<Writable>,
) : AuthSchemeOption

class CustomResolver(/* unimplemented */) : AuthSchemeOption
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customizations.noAuthSchemeShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOption
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.isEmpty
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import java.util.logging.Logger

class AuthOptionsPluginGenerator(private val codegenContext: ClientCodegenContext) {
private val runtimeConfig = codegenContext.runtimeConfig

private val logger: Logger = Logger.getLogger(javaClass.name)
fun authPlugin(operationShape: OperationShape, authSchemeOptions: List<AuthSchemeOption>) = writable {
rustTemplate(
"""
#{DefaultAuthOptionsPlugin}::new(vec![#{options}])

""",
"DefaultAuthOptionsPlugin" to RuntimeType.defaultAuthPlugin(runtimeConfig),
"options" to actualAuthSchemes(operationShape, authSchemeOptions).join(", "),
)
}

private fun actualAuthSchemes(operationShape: OperationShape, authSchemeOptions: List<AuthSchemeOption>): List<Writable> {
val out: MutableList<Writable> = mutableListOf()

var noSupportedAuthSchemes = true
val authSchemes = ServiceIndex.of(codegenContext.model)
.getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape)

for (schemeShapeId in authSchemes.keys) {
val optionsForScheme = authSchemeOptions.filter {
when (it) {
is AuthSchemeOption.CustomResolver -> false
is AuthSchemeOption.StaticAuthSchemeOption -> {
it.schemeShapeId == schemeShapeId
}
}
}

if (optionsForScheme.isNotEmpty()) {
out.addAll(optionsForScheme.flatMap { (it as AuthSchemeOption.StaticAuthSchemeOption).constructor })
noSupportedAuthSchemes = false
} else {
logger.warning(
"No auth scheme implementation available for $schemeShapeId. " +
"The generated client will not attempt to use this auth scheme.",
)
}
}
if (operationShape.hasTrait<OptionalAuthTrait>() || noSupportedAuthSchemes) {
val authOption = authSchemeOptions.find {
it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId
}
?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.")
out += (authOption as AuthSchemeOption.StaticAuthSchemeOption).constructor
}
if (out.any { it.isEmpty() }) {
PANIC("Got an empty auth scheme constructor. This is a bug. $out")
}
return out
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ open class OperationGenerator(
operationShape: OperationShape,
codegenDecorator: ClientCodegenDecorator,
) {
val operationCustomizations = codegenDecorator.operationCustomizations(codegenContext, operationShape, emptyList())
val operationCustomizations =
codegenDecorator.operationCustomizations(codegenContext, operationShape, emptyList())
renderOperationStruct(
operationWriter,
operationShape,
Expand Down Expand Up @@ -94,14 +95,22 @@ open class OperationGenerator(
"Operation" to symbolProvider.toSymbol(operationShape),
"OperationError" to errorType,
"OperationOutput" to outputType,
"HttpResponse" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::HttpResponse"),
"HttpResponse" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::orchestrator::HttpResponse"),
"SdkError" to RuntimeType.sdkError(runtimeConfig),
)
val additionalPlugins = writable {
writeCustomizations(
operationCustomizations,
OperationSection.AdditionalRuntimePlugins(operationCustomizations, operationShape),
)
rustTemplate(
".with_client_plugin(#{auth_plugin})",
"auth_plugin" to AuthOptionsPluginGenerator(codegenContext).authPlugin(
operationShape,
authSchemeOptions,
),
)
}
rustTemplate(
"""
Expand Down Expand Up @@ -157,11 +166,13 @@ open class OperationGenerator(
*codegenScope,
"Error" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::interceptors::context::Error"),
"InterceptorContext" to RuntimeType.interceptorContext(runtimeConfig),
"OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig).resolve("client::orchestrator::error::OrchestratorError"),
"OrchestratorError" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::orchestrator::error::OrchestratorError"),
"RuntimePlugin" to RuntimeType.runtimePlugin(runtimeConfig),
"RuntimePlugins" to RuntimeType.runtimePlugins(runtimeConfig),
"StopPoint" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::StopPoint"),
"invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig).resolve("client::orchestrator::invoke_with_stop_point"),
"invoke_with_stop_point" to RuntimeType.smithyRuntime(runtimeConfig)
.resolve("client::orchestrator::invoke_with_stop_point"),
"additional_runtime_plugins" to writable {
if (additionalPlugins.isNotEmpty()) {
rustTemplate(
Expand All @@ -182,7 +193,6 @@ open class OperationGenerator(
operationWriter,
operationShape,
operationName,
authSchemeOptions,
operationCustomizations,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,22 @@

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.OptionalAuthTrait
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.customizations.noAuthSchemeShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.AuthSchemeOption
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import java.util.logging.Logger

/**
* Generates operation-level runtime plugins
*/
class OperationRuntimePluginGenerator(
private val codegenContext: ClientCodegenContext,
) {
private val logger: Logger = Logger.getLogger(javaClass.name)
private val codegenScope = codegenContext.runtimeConfig.let { rc ->
val runtimeApi = RuntimeType.smithyRuntimeApi(rc)
val smithyTypes = RuntimeType.smithyTypes(rc)
Expand Down Expand Up @@ -57,7 +48,6 @@ class OperationRuntimePluginGenerator(
writer: RustWriter,
operationShape: OperationShape,
operationStructName: String,
authSchemeOptions: List<AuthSchemeOption>,
customizations: List<OperationCustomization>,
) {
writer.rustTemplate(
Expand All @@ -80,7 +70,6 @@ class OperationRuntimePluginGenerator(
fn runtime_components(&self, _: &#{RuntimeComponentsBuilder}) -> #{Cow}<'_, #{RuntimeComponentsBuilder}> {
#{Cow}::Owned(
#{RuntimeComponentsBuilder}::new(${operationShape.id.name.dq()})
#{auth_options}
#{interceptors}
#{retry_classifiers}
)
Expand All @@ -91,7 +80,6 @@ class OperationRuntimePluginGenerator(
""",
*codegenScope,
*preludeScope,
"auth_options" to generateAuthOptions(operationShape, authSchemeOptions),
"additional_config" to writable {
writeCustomizations(
customizations,
Expand Down Expand Up @@ -122,55 +110,4 @@ class OperationRuntimePluginGenerator(
},
)
}

private fun generateAuthOptions(
operationShape: OperationShape,
authSchemeOptions: List<AuthSchemeOption>,
): Writable = writable {
if (authSchemeOptions.any { it is AuthSchemeOption.CustomResolver }) {
throw IllegalStateException("AuthSchemeOption.CustomResolver is unimplemented")
} else {
withBlockTemplate(
"""
.with_auth_scheme_option_resolver(#{Some}(
#{SharedAuthSchemeOptionResolver}::new(
#{StaticAuthSchemeOptionResolver}::new(vec![
""",
"]))))",
*codegenScope,
) {
var noSupportedAuthSchemes = true
val authSchemes = ServiceIndex.of(codegenContext.model)
.getEffectiveAuthSchemes(codegenContext.serviceShape, operationShape)

for (schemeShapeId in authSchemes.keys) {
val optionsForScheme = authSchemeOptions.filter {
when (it) {
is AuthSchemeOption.CustomResolver -> false
is AuthSchemeOption.StaticAuthSchemeOption -> {
it.schemeShapeId == schemeShapeId
}
}
}

if (optionsForScheme.isNotEmpty()) {
optionsForScheme.forEach { (it as AuthSchemeOption.StaticAuthSchemeOption).constructor(this) }
noSupportedAuthSchemes = false
} else {
logger.warning(
"No auth scheme implementation available for $schemeShapeId. " +
"The generated client will not attempt to use this auth scheme.",
)
}
}
if (operationShape.hasTrait<OptionalAuthTrait>() || noSupportedAuthSchemes) {
val authOption = authSchemeOptions.find {
it is AuthSchemeOption.StaticAuthSchemeOption && it.schemeShapeId == noAuthSchemeShapeId
}
?: throw IllegalStateException("Missing 'no auth' implementation. This is a codegen bug.")
(authOption as AuthSchemeOption.StaticAuthSchemeOption).constructor(this)
}
}
}
}
}
Loading