From 18fbd38fe599a45fd4a025d256731dbcb21bb813 Mon Sep 17 00:00:00 2001 From: Janet Gainer-Dewar Date: Fri, 19 Jul 2024 19:11:19 -0400 Subject: [PATCH] WX-1594 Support for private ACR access in Terra (#715) Co-authored-by: Blair L Murri --- src/CommonUtilities/Models/NodeTask.cs | 1 + src/Tes.ApiClients.Tests/TerraApiStubData.cs | 28 ++++++ .../TerraSamApiClientTests.cs | 90 +++++++++++++++++++ src/Tes.ApiClients/HttpApiClient.cs | 28 +++++- .../SamActionManagedIdentityApiResponse.cs | 61 +++++++++++++ src/Tes.ApiClients/TerraSamApiClient.cs | 87 ++++++++++++++++++ ...tainerRegistryAuthorizationManagerTests.cs | 2 +- .../Docker/DockerExecutorTests.cs | 2 +- .../Authentication/CredentialsManager.cs | 30 +++++-- .../ContainerRegistryAuthorizationManager.cs | 5 +- src/TesApi.Tests/BatchSchedulerTests.cs | 41 ++++----- .../Runner/TaskToNodeTaskConverterTests.cs | 45 ++++++++-- src/TesApi.Tests/StartupTests.cs | 42 +++++++-- .../TerraActionIdentityProviderTests.cs | 79 ++++++++++++++++ src/TesApi.Tests/TerraApiStubData.cs | 37 +++++++- .../TestServices/TestServiceProvider.cs | 1 + src/TesApi.Web/AzureProxy.cs | 9 ++ src/TesApi.Web/BatchScheduler.cs | 46 +++++++--- .../CachingWithRetriesAzureProxy.cs | 3 + .../DefaultActionIdentityProvider.cs | 22 +++++ src/TesApi.Web/IActionIdentityProvider.cs | 23 +++++ src/TesApi.Web/IAzureProxy.cs | 7 ++ .../Management/Configuration/TerraOptions.cs | 18 +++- src/TesApi.Web/Runner/NodeTaskBuilder.cs | 28 +++++- .../Runner/TaskExecutionScriptingManager.cs | 4 +- .../Runner/TaskToNodeTaskConverter.cs | 65 +++++++++++--- src/TesApi.Web/Startup.cs | 26 ++++++ src/TesApi.Web/TerraActionIdentityProvider.cs | 71 +++++++++++++++ 28 files changed, 822 insertions(+), 79 deletions(-) create mode 100644 src/Tes.ApiClients.Tests/TerraSamApiClientTests.cs create mode 100644 src/Tes.ApiClients/Models/Terra/SamActionManagedIdentityApiResponse.cs create mode 100644 src/Tes.ApiClients/TerraSamApiClient.cs create mode 100644 src/TesApi.Tests/TerraActionIdentityProviderTests.cs create mode 100644 src/TesApi.Web/DefaultActionIdentityProvider.cs create mode 100644 src/TesApi.Web/IActionIdentityProvider.cs create mode 100644 src/TesApi.Web/TerraActionIdentityProvider.cs diff --git a/src/CommonUtilities/Models/NodeTask.cs b/src/CommonUtilities/Models/NodeTask.cs index 1b7126456..d133cb010 100644 --- a/src/CommonUtilities/Models/NodeTask.cs +++ b/src/CommonUtilities/Models/NodeTask.cs @@ -46,6 +46,7 @@ public class RuntimeOptions public TerraRuntimeOptions? Terra { get; set; } public string? NodeManagedIdentityResourceId { get; set; } + public string? AcrPullManagedIdentityResourceId { get; set; } public StorageTargetLocation? StorageEventSink { get; set; } diff --git a/src/Tes.ApiClients.Tests/TerraApiStubData.cs b/src/Tes.ApiClients.Tests/TerraApiStubData.cs index cf350fcdd..18b51edf8 100644 --- a/src/Tes.ApiClients.Tests/TerraApiStubData.cs +++ b/src/Tes.ApiClients.Tests/TerraApiStubData.cs @@ -10,6 +10,7 @@ public class TerraApiStubData { public const string LandingZoneApiHost = "https://landingzone.host"; public const string WsmApiHost = "https://wsm.host"; + public const string SamApiHost = "https://sam.host"; public const string ResourceGroup = "mrg-terra-dev-previ-20191228"; public const string WorkspaceAccountName = "lzaccount1"; public const string SasToken = "SASTOKENSTUB="; @@ -18,8 +19,12 @@ public class TerraApiStubData public const string WorkspaceStorageContainerName = $"sc-{WorkspaceIdValue}"; public const string WsmGetSasResponseStorageUrl = $"https://{WorkspaceAccountName}.blob.core.windows.net/{WorkspaceStorageContainerName}"; + public const string TerraPetName = "pet-2674060218359759651b0"; + + public Guid TenantId { get; } = Guid.NewGuid(); public Guid LandingZoneId { get; } = Guid.NewGuid(); public Guid SubscriptionId { get; } = Guid.NewGuid(); + public Guid AcrPullIdentitySamResourceId { get; } = Guid.NewGuid(); public Guid ContainerResourceId { get; } = Guid.NewGuid(); public Guid WorkspaceId { get; } = Guid.Parse(WorkspaceIdValue); @@ -28,6 +33,8 @@ public class TerraApiStubData public string BatchAccountId => $"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.Batch/batchAccounts/{BatchAccountName}"; + public string ManagedIdentityObjectId => + $"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{TerraPetName}"; public string PoolId => "poolId"; public Guid GetWorkspaceIdFromContainerName(string containerName) @@ -279,6 +286,27 @@ public string GetResourceQuotaApiResponseInJson() }}"; } + public string GetSamActionManagedIdentityApiResponseInJson() + { + return $@"{{ + ""id"": {{ + ""resourceId"": {{ + ""resourceTypeName"": ""private_azure_container_registry"", + ""resourceId"": ""{AcrPullIdentitySamResourceId}"" + }}, + ""action"": ""pull_image"", + ""billingProfileId"": ""{AcrPullIdentitySamResourceId}"" + }}, + ""objectId"": ""{ManagedIdentityObjectId}"", + ""displayName"": ""my nice action identity"", + ""managedResourceGroupCoordinates"": {{ + ""tenantId"": ""{TenantId}"", + ""subscriptionId"": ""{SubscriptionId}"", + ""managedResourceGroupName"": ""{ResourceGroup}"" + }} +}}"; + } + public ApiCreateBatchPoolRequest GetApiCreateBatchPoolRequest() { return new ApiCreateBatchPoolRequest() diff --git a/src/Tes.ApiClients.Tests/TerraSamApiClientTests.cs b/src/Tes.ApiClients.Tests/TerraSamApiClientTests.cs new file mode 100644 index 000000000..372ec82e1 --- /dev/null +++ b/src/Tes.ApiClients.Tests/TerraSamApiClientTests.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Core; +using CommonUtilities; +using Microsoft.Extensions.Logging.Abstractions; +using Moq; +using Tes.ApiClients.Models.Terra; + +namespace Tes.ApiClients.Tests +{ + [TestClass, TestCategory("Unit")] + public class TerraSamApiClientTests + { + private TerraSamApiClient terraSamApiClient = null!; + private Mock tokenCredential = null!; + private Mock cacheAndRetryBuilder = null!; + private Lazy>> cacheAndRetryHandler = null!; + private TerraApiStubData terraApiStubData = null!; + private AzureEnvironmentConfig azureEnvironmentConfig = null!; + private TimeSpan cacheTTL = TimeSpan.FromMinutes(1); + + [TestInitialize] + public void SetUp() + { + terraApiStubData = new TerraApiStubData(); + tokenCredential = new Mock(); + cacheAndRetryBuilder = new Mock(); + var cache = new Mock(); + cache.Setup(c => c.CreateEntry(It.IsAny())).Returns(new Mock().Object); + cacheAndRetryBuilder.SetupGet(c => c.AppCache).Returns(cache.Object); + cacheAndRetryHandler = new(TestServices.RetryHandlersHelpers.GetCachingAsyncRetryPolicyMock(cacheAndRetryBuilder, c => c.DefaultRetryHttpResponseMessagePolicyBuilder())); + azureEnvironmentConfig = ExpensiveObjectTestUtility.AzureCloudConfig.AzureEnvironmentConfig!; + + terraSamApiClient = new TerraSamApiClient(TerraApiStubData.SamApiHost, tokenCredential.Object, + cacheAndRetryBuilder.Object, azureEnvironmentConfig, NullLogger.Instance); + } + + [TestMethod] + public async Task GetActionManagedIdentityAsync_ValidRequest_ReturnsPayload() + { + cacheAndRetryHandler.Value.Setup(c => c.ExecuteWithRetryConversionAndCachingAsync( + It.IsAny(), + It.IsAny>>(), + It.IsAny>>(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(System.Text.Json.JsonSerializer.Deserialize(terraApiStubData.GetSamActionManagedIdentityApiResponseInJson())!); + + var apiResponse = await terraSamApiClient.GetActionManagedIdentityForACRPullAsync(terraApiStubData.AcrPullIdentitySamResourceId, cacheTTL, CancellationToken.None); + + Assert.IsNotNull(apiResponse); + Assert.IsTrue(!string.IsNullOrEmpty(apiResponse.ObjectId)); + Assert.IsTrue(apiResponse.ObjectId.Contains(TerraApiStubData.TerraPetName)); + } + + [TestMethod] + public async Task GetActionManagedIdentityAsync_ValidRequest_Returns404() + { + cacheAndRetryHandler.Value.Setup(c => c.ExecuteWithRetryConversionAndCachingAsync( + It.IsAny(), + It.IsAny>>(), + It.IsAny>>(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Throws(new HttpRequestException(null, null, System.Net.HttpStatusCode.NotFound)); + + var apiResponse = await terraSamApiClient.GetActionManagedIdentityForACRPullAsync(terraApiStubData.AcrPullIdentitySamResourceId, cacheTTL, CancellationToken.None); + + Assert.IsNull(apiResponse); + } + + [TestMethod] + public async Task GetActionManagedIdentityAsync_ValidRequest_Returns500() + { + cacheAndRetryHandler.Value.Setup(c => c.ExecuteWithRetryConversionAndCachingAsync( + It.IsAny(), + It.IsAny>>(), + It.IsAny>>(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Throws(new HttpRequestException(null, null, System.Net.HttpStatusCode.BadGateway)); + + await Assert.ThrowsExceptionAsync(async () => await terraSamApiClient.GetActionManagedIdentityForACRPullAsync(terraApiStubData.AcrPullIdentitySamResourceId, cacheTTL, CancellationToken.None)); + } + } +} diff --git a/src/Tes.ApiClients/HttpApiClient.cs b/src/Tes.ApiClients/HttpApiClient.cs index 782477e3b..a7486b391 100644 --- a/src/Tes.ApiClients/HttpApiClient.cs +++ b/src/Tes.ApiClients/HttpApiClient.cs @@ -162,6 +162,32 @@ protected async Task HttpGetRequestWithCachingAndRetryPolicyAsync GetApiResponseContentAsync(m, typeInfo, ct), cancellationToken))!; } + /// + /// Checks the cache and if the request was not found, sends the GET request with a retry policy. + /// If the GET request is successful, adds it to the cache with the specified TTL. + /// + /// + /// JSON serialization-related metadata. + /// Time after which a newly-added entry will expire from the cache. + /// A for controlling the lifetime of the asynchronous operation. + /// If true, the authentication header is set with an authentication token. + /// Response's content deserialization type. + /// + protected async Task HttpGetRequestWithExpirableCachingAndRetryPolicyAsync(Uri requestUrl, + JsonTypeInfo typeInfo, TimeSpan cacheTTL, CancellationToken cancellationToken, bool setAuthorizationHeader = false) + { + var cacheKey = await ToCacheKeyAsync(requestUrl, setAuthorizationHeader, cancellationToken); + + return (await cachingRetryHandler.ExecuteWithRetryConversionAndCachingAsync(cacheKey, async ct => + { + //request must be recreated in every retry. + var httpRequest = await CreateGetHttpRequest(requestUrl, setAuthorizationHeader, ct); + + return await HttpClient.SendAsync(httpRequest, ct); + }, + (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), DateTimeOffset.Now + cacheTTL, cancellationToken))!; + } + /// /// Get request with retry policy /// @@ -223,7 +249,7 @@ private async Task CreateGetHttpRequest(Uri requestUrl, bool } /// - /// Sends an Http request to the URL and deserializes the body response to the specified type + /// Sends an Http request to the URL and deserializes the body response to the specified type /// /// Factory that creates new http requests, in the event of retry the factory is called again /// and must be idempotent. diff --git a/src/Tes.ApiClients/Models/Terra/SamActionManagedIdentityApiResponse.cs b/src/Tes.ApiClients/Models/Terra/SamActionManagedIdentityApiResponse.cs new file mode 100644 index 000000000..912139528 --- /dev/null +++ b/src/Tes.ApiClients/Models/Terra/SamActionManagedIdentityApiResponse.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Tes.ApiClients.Models.Terra +{ + public class SamActionManagedIdentityApiResponse + { + [JsonPropertyName("id")] + public ActionManagedIdentityId actionManagedIdentityId { get; set; } + + [JsonPropertyName("displayName")] + public string DisplayName { get; set; } + + [JsonPropertyName("managedResourceGroupCoordinates")] + public ManagedResourceGroupCoordinates managedResourceGroupCoordinates { get; set; } + + [JsonPropertyName("objectId")] + public string ObjectId { get; set; } + + } + + public class ActionManagedIdentityId + { + [JsonPropertyName("resourceId")] + public FullyQualifiedResourceId ResourceId { get; set; } + + [JsonPropertyName("action")] + public string Action { get; set; } + + [JsonPropertyName("billingProfileId")] + public Guid BillingProfileId { get; set; } + } + + public class FullyQualifiedResourceId + { + [JsonPropertyName("resourceTypeName")] + public string ResourceTypeName { get; set; } + + [JsonPropertyName("resourceId")] + public string ResourceId { get; set; } + } + + public class ManagedResourceGroupCoordinates + { + [JsonPropertyName("tenantId")] + public Guid TenantId { get; set; } + + [JsonPropertyName("subscriptionId")] + public Guid SubscriptionId { get; set; } + + [JsonPropertyName("managedResourceGroupName")] + public string ManagedResourceGroupName { get; set; } + } + + [JsonSerializable(typeof(SamActionManagedIdentityApiResponse))] + public partial class SamActionManagedIdentityApiResponseContext : JsonSerializerContext + { } + +} diff --git a/src/Tes.ApiClients/TerraSamApiClient.cs b/src/Tes.ApiClients/TerraSamApiClient.cs new file mode 100644 index 000000000..254ba5122 --- /dev/null +++ b/src/Tes.ApiClients/TerraSamApiClient.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Core; +using CommonUtilities; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Logging; +using Tes.ApiClients.Models.Terra; + +namespace Tes.ApiClients +{ + /// + /// Terra Sam api client + /// Sam manages authorization and IAM functionality + /// + public class TerraSamApiClient : TerraApiClient + { + private const string SamApiSegments = @"/api/azure/v1"; + + private static readonly IMemoryCache SharedMemoryCache = new MemoryCache(new MemoryCacheOptions()); + + /// + /// Constructor of TerraSamApiClient + /// + /// Sam API host + /// + /// + /// + /// + public TerraSamApiClient(string apiUrl, TokenCredential tokenCredential, CachingRetryPolicyBuilder cachingRetryHandler, + AzureEnvironmentConfig azureCloudIdentityConfig, ILogger logger) : base(apiUrl, tokenCredential, cachingRetryHandler, azureCloudIdentityConfig, logger) + { } + + public static TerraSamApiClient CreateTerraSamApiClient(string apiUrl, TokenCredential tokenCredential, AzureEnvironmentConfig azureCloudIdentityConfig) + { + return CreateTerraApiClient(apiUrl, SharedMemoryCache, tokenCredential, azureCloudIdentityConfig); + } + + /// + /// Protected parameter-less constructor + /// + protected TerraSamApiClient() { } + + public virtual async Task GetActionManagedIdentityForACRPullAsync(Guid resourceId, TimeSpan cacheTTL, CancellationToken cancellationToken) + { + return await GetActionManagedIdentityAsync("private_azure_container_registry", resourceId, "pull_image", cacheTTL, cancellationToken); + } + + private async Task GetActionManagedIdentityAsync(string resourceType, Guid resourceId, string action, TimeSpan cacheTTL, CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(resourceId); + + var url = GetSamActionManagedIdentityUrl(resourceType, resourceId, action); + + Logger.LogInformation(@"Fetching action managed identity from Sam for {resourceId}", resourceId); + + try + { + return await HttpGetRequestWithExpirableCachingAndRetryPolicyAsync(url, + SamActionManagedIdentityApiResponseContext.Default.SamActionManagedIdentityApiResponse, cacheTTL, cancellationToken, setAuthorizationHeader: true); + } + catch (HttpRequestException e) + { + // Sam will return a 404 if there is no action identity that matches the query, + // or if we don't have access to it. + if (e.StatusCode == System.Net.HttpStatusCode.NotFound) + { + return null; + } + else + { + throw; + } + + } + } + + public virtual Uri GetSamActionManagedIdentityUrl(string resourceType, Guid resourceId, string action) + { + var apiRequestUrl = $"{ApiUrl.TrimEnd('/')}{SamApiSegments}/actionManagedIdentity/{resourceType}/{resourceId}/{action}"; + + var uriBuilder = new UriBuilder(apiRequestUrl); + + return uriBuilder.Uri; + } + } +} diff --git a/src/Tes.Runner.Test/Docker/ContainerRegistryAuthorizationManagerTests.cs b/src/Tes.Runner.Test/Docker/ContainerRegistryAuthorizationManagerTests.cs index 2b9ff1d13..b49474202 100644 --- a/src/Tes.Runner.Test/Docker/ContainerRegistryAuthorizationManagerTests.cs +++ b/src/Tes.Runner.Test/Docker/ContainerRegistryAuthorizationManagerTests.cs @@ -32,7 +32,7 @@ public void SetUp() mockCredentialsManager = new Mock(); mockCredentials = new Mock(); - mockCredentialsManager.Setup(c => c.GetTokenCredential(It.IsAny(), It.IsAny())) + mockCredentialsManager.Setup(c => c.GetAcrPullTokenCredential(It.IsAny(), It.IsAny())) .Returns(mockCredentials.Object); diff --git a/src/Tes.Runner.Test/Docker/DockerExecutorTests.cs b/src/Tes.Runner.Test/Docker/DockerExecutorTests.cs index 7259c395a..58d5a3b1f 100644 --- a/src/Tes.Runner.Test/Docker/DockerExecutorTests.cs +++ b/src/Tes.Runner.Test/Docker/DockerExecutorTests.cs @@ -35,7 +35,7 @@ public void SetUp() dockerClientMock.Setup(d => d.Volumes).Returns(dockerVolumeMock.Object); dockerClient = dockerClientMock.Object; var credentialsManager = new Mock(); - credentialsManager.Setup(m => m.GetTokenCredential(It.IsAny(), It.IsAny())) + credentialsManager.Setup(m => m.GetAcrPullTokenCredential(It.IsAny(), It.IsAny())) .Throws(new IdentityUnavailableException()); containerRegistryAuthorizationManager = new(credentialsManager.Object); } diff --git a/src/Tes.Runner/Authentication/CredentialsManager.cs b/src/Tes.Runner/Authentication/CredentialsManager.cs index d1c5bf98f..31c06c9f7 100644 --- a/src/Tes.Runner/Authentication/CredentialsManager.cs +++ b/src/Tes.Runner/Authentication/CredentialsManager.cs @@ -38,9 +38,25 @@ private TimeSpan SleepDurationHandler(int attempt) public virtual TokenCredential GetTokenCredential(RuntimeOptions runtimeOptions, string? tokenScope = default) { + return GetTokenCredential(runtimeOptions, runtimeOptions.NodeManagedIdentityResourceId, tokenScope); + } + + public virtual TokenCredential GetAcrPullTokenCredential(RuntimeOptions runtimeOptions, string? tokenScope = default) + { + var managedIdentity = runtimeOptions.NodeManagedIdentityResourceId; + if (!string.IsNullOrWhiteSpace(runtimeOptions.AcrPullManagedIdentityResourceId)) + { + managedIdentity = runtimeOptions.AcrPullManagedIdentityResourceId; + } + return GetTokenCredential(runtimeOptions, managedIdentity, tokenScope); + } + + public virtual TokenCredential GetTokenCredential(RuntimeOptions runtimeOptions, string? managedIdentityResourceId, string? tokenScope = default) + { + tokenScope ??= runtimeOptions.AzureEnvironmentConfig!.TokenScope!; try { - return retryPolicy.Execute(() => GetTokenCredentialImpl(runtimeOptions, tokenScope)); + return retryPolicy.Execute(() => GetTokenCredentialImpl(managedIdentityResourceId, tokenScope, runtimeOptions.AzureEnvironmentConfig!.AzureAuthorityHostUrl!)); } catch { @@ -48,22 +64,20 @@ public virtual TokenCredential GetTokenCredential(RuntimeOptions runtimeOptions, } } - private TokenCredential GetTokenCredentialImpl(RuntimeOptions runtimeOptions, string? tokenScope) + private TokenCredential GetTokenCredentialImpl(string? managedIdentityResourceId, string tokenScope, string azureAuthorityHost) { - tokenScope ??= runtimeOptions.AzureEnvironmentConfig!.TokenScope!; - try { TokenCredential tokenCredential; - Uri authorityHost = new(runtimeOptions.AzureEnvironmentConfig!.AzureAuthorityHostUrl!); + Uri authorityHost = new(azureAuthorityHost); - if (!string.IsNullOrWhiteSpace(runtimeOptions.NodeManagedIdentityResourceId)) + if (!string.IsNullOrWhiteSpace(managedIdentityResourceId)) { - logger.LogInformation("Token credentials with Managed Identity and resource ID: {NodeManagedIdentityResourceId}", runtimeOptions.NodeManagedIdentityResourceId); + logger.LogInformation("Token credentials with Managed Identity and resource ID: {NodeManagedIdentityResourceId}", managedIdentityResourceId); var tokenCredentialOptions = new TokenCredentialOptions { AuthorityHost = authorityHost }; tokenCredential = new ManagedIdentityCredential( - new ResourceIdentifier(runtimeOptions.NodeManagedIdentityResourceId), + new ResourceIdentifier(managedIdentityResourceId), tokenCredentialOptions); } else diff --git a/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs b/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs index 5a431753d..1daa9316c 100644 --- a/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs +++ b/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs @@ -3,8 +3,11 @@ using Azure.Containers.ContainerRegistry; using Azure.Core; +using CommonUtilities; +using CommonUtilities.AzureCloud; using Docker.DotNet.Models; using Microsoft.Extensions.Logging; +using Tes.ApiClients; using Tes.Runner.Authentication; using Tes.Runner.Models; using Tes.Runner.Transfer; @@ -79,7 +82,7 @@ public ContainerRegistryContentClient CreateContainerRegistryContentClientWithAc // Use a pipeline policy to get access to the ACR access token we will need to pass to Docker. var clientOptions = new ContainerRegistryClientOptions(); clientOptions.AddPolicy(new AcquireDockerAuthTokenPipelinePolicy(onCapture), HttpPipelinePosition.PerCall); - return new ContainerRegistryContentClient(endpoint, repositoryName, tokenCredentialsManager.GetTokenCredential(runtimeOptions), clientOptions); + return new ContainerRegistryContentClient(endpoint, repositoryName, tokenCredentialsManager.GetAcrPullTokenCredential(runtimeOptions), clientOptions); } private sealed class AcquireDockerAuthTokenPipelinePolicy : Azure.Core.Pipeline.HttpPipelinePolicy diff --git a/src/TesApi.Tests/BatchSchedulerTests.cs b/src/TesApi.Tests/BatchSchedulerTests.cs index 30280e3da..9bb26b226 100644 --- a/src/TesApi.Tests/BatchSchedulerTests.cs +++ b/src/TesApi.Tests/BatchSchedulerTests.cs @@ -634,7 +634,7 @@ public async Task BatchJobContainsExpectedBatchPoolInformation() GuardAssertsWithTesTask(tesTask, () => { - Assert.AreEqual("TES-hostname-edicated1-rpsd645merzfkqmdnj7pkqrase2ancnh-", tesTask.PoolId[0..^8]); + Assert.AreEqual("TES-hostname-edicated1-obkfufnroslrzwlitqbrmjeowu7iuhfm-", tesTask.PoolId[0..^8]); Assert.AreEqual("VmSizeDedicated1", pool.VmSize); Assert.IsTrue(((BatchScheduler)batchScheduler).TryGetPool(tesTask.PoolId, out _)); }); @@ -658,8 +658,8 @@ public async Task BatchJobContainsExpectedManualPoolInformation() { Assert.AreEqual("VmSizeDedicated1", poolSpec.VmSize); Assert.IsTrue(poolSpec.ScaleSettings.AutoScale.Formula.Contains("TargetDedicated")); - Assert.AreEqual(1, poolSpec.Identity.UserAssignedIdentities.Count); - Assert.AreEqual(identity, poolSpec.Identity.UserAssignedIdentities.Keys.First()); + Assert.AreEqual(2, poolSpec.Identity.UserAssignedIdentities.Count); + Assert.AreEqual(identity, poolSpec.Identity.UserAssignedIdentities.Keys.Skip(1).First()); }); } @@ -1316,6 +1316,9 @@ private static Action> GetMockBatchPoolManager(AzureProx private static Action> GetMockAzureProxy(AzureProxyReturnValues azureProxyReturnValues) => azureProxy => { + azureProxy.Setup(a => a.GetManagedIdentityInBatchAccountResourceGroup(It.IsAny())) + .Returns(name => $"/subscriptions/defaultsubscription/resourceGroups/defaultresourcegroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{name}"); + azureProxy.Setup(a => a.BlobExistsAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(true); @@ -1363,23 +1366,21 @@ private static Action> GetMockAzureProxy(AzureProxyReturnValue private static Func> GetMockConfig() => new(() => - { - var config = Enumerable.Empty<(string Key, string Value)>() - .Append(("Storage:DefaultAccountName", "defaultstorageaccount")) - .Append(("BatchScheduling:Prefix", "hostname")) - .Append(("BatchImageGen1:Offer", "ubuntu-server-container")) - .Append(("BatchImageGen1:Publisher", "microsoft-azure-batch")) - .Append(("BatchImageGen1:Sku", "20-04-lts")) - .Append(("BatchImageGen1:Version", "latest")) - .Append(("BatchImageGen1:NodeAgentSkuId", "batch.node.ubuntu 20.04")) - .Append(("BatchImageGen2:Offer", "ubuntu-hpc")) - .Append(("BatchImageGen2:Publisher", "microsoft-dsvm")) - .Append(("BatchImageGen2:Sku", "2004")) - .Append(("BatchImageGen2:Version", "latest")) - .Append(("BatchImageGen2:NodeAgentSkuId", "batch.node.ubuntu 20.04")); - - return config; - }); + [ + ("Storage:DefaultAccountName", "defaultstorageaccount"), + ("BatchNodes:GlobalManagedIdentity", "/subscriptions/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/resourceGroups/SomeResourceGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/GlobalManagedIdentity"), + ("BatchScheduling:Prefix", "hostname"), + ("BatchImageGen1:Offer", "ubuntu-server-container"), + ("BatchImageGen1:Publisher", "microsoft-azure-batch"), + ("BatchImageGen1:Sku", "20-04-lts"), + ("BatchImageGen1:Version", "latest"), + ("BatchImageGen1:NodeAgentSkuId", "batch.node.ubuntu 20.04"), + ("BatchImageGen2:Offer", "ubuntu-hpc"), + ("BatchImageGen2:Publisher", "microsoft-dsvm"), + ("BatchImageGen2:Sku", "2004"), + ("BatchImageGen2:Version", "latest"), + ("BatchImageGen2:NodeAgentSkuId", "batch.node.ubuntu 20.04"), + ]); private static IEnumerable GetFilesToDownload(Mock azureProxy) { diff --git a/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs b/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs index 9e3366442..db5baf5e8 100644 --- a/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs +++ b/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs @@ -33,7 +33,6 @@ public class TaskToNodeTaskConverterTests private readonly TesTask tesTask = GetTestTesTask(); private TerraOptions terraOptions; private StorageOptions storageOptions; - private BatchAccountOptions batchAccountOptions; private const string SasToken = "sv=2019-12-12&ss=bfqt&srt=sco&spr=https&st=2023-09-27T17%3A32%3A57Z&se=2023-09-28T17%3A32%3A57Z&sp=rwdlacupx&sig=SIGNATURE"; @@ -55,10 +54,14 @@ public class TaskToNodeTaskConverterTests [TestInitialize] public void SetUp() { - terraOptions = new TerraOptions(); - storageOptions = new StorageOptions() { ExternalStorageContainers = ExternalStorageContainerWithSas }; - batchAccountOptions = new BatchAccountOptions() { SubscriptionId = SubscriptionId, ResourceGroup = ResourceGroup }; - storageAccessProviderMock = new Mock(); + terraOptions = new(); + storageOptions = new() { ExternalStorageContainers = ExternalStorageContainerWithSas }; + + Mock azureProxyMock = new(); + azureProxyMock.Setup(x => x.GetManagedIdentityInBatchAccountResourceGroup(It.IsAny())) + .Returns(name => $"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{name}"); + + storageAccessProviderMock = new(); storageAccessProviderMock.Setup(x => x.GetInternalTesTaskBlobUrlAsync(It.IsAny(), It.IsAny(), It.IsAny())) .ReturnsAsync(InternalBlobUrlWithSas); @@ -76,8 +79,8 @@ public void SetUp() .Returns(Task.FromResult>([])); var azureCloudIdentityConfig = AzureCloudConfig.FromKnownCloudNameAsync().Result.AzureEnvironmentConfig; - taskToNodeTaskConverter = new TaskToNodeTaskConverter(Options.Create(terraOptions), storageAccessProviderMock.Object, - Options.Create(storageOptions), Options.Create(batchAccountOptions), azureCloudIdentityConfig, new NullLogger()); + taskToNodeTaskConverter = new TaskToNodeTaskConverter(Options.Create(terraOptions), Options.Create(storageOptions), + storageAccessProviderMock.Object, azureProxyMock.Object, azureCloudIdentityConfig, new NullLogger()); } @@ -117,6 +120,34 @@ public void GetNodeManagedIdentityResourceId_ResourceIsProvided_ReturnsExpectedR Assert.AreEqual(expectedResourceId, resourceId); } + [DataTestMethod] + [DataRow("myIdentity", $@"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myIdentity", false)] + [DataRow($@"/subscriptions/{SubscriptionId}/resourcegroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myIdentity", $@"/subscriptions/{SubscriptionId}/resourcegroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myIdentity", false)] + [DataRow($@"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myIdentity", $@"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myIdentity", false)] + [DataRow("", null, true)] + [DataRow(null, null, true)] + public void GetNodeManagedIdentityResourceId_NoGlobalManagedIdentity_ReturnsExpectedResult(string workflowIdentity, string expectedResourceId, bool exceptionExpected) + { + tesTask.Resources = new TesResources() + { + BackendParameters = new Dictionary() + { + {TesResources.SupportedBackendParameters.workflow_execution_identity.ToString(), workflowIdentity} + } + }; + + try + { + var resourceId = taskToNodeTaskConverter.GetNodeManagedIdentityResourceId(string.Empty, tesTask); + Assert.AreEqual(exceptionExpected, false); + Assert.AreEqual(expectedResourceId, resourceId); + } + catch (TesException) + { + Assert.AreEqual(exceptionExpected, true); + } + } + [TestMethod] public async Task ToNodeTaskAsync_TesTask_OutputsContainLogsAndMetrics() diff --git a/src/TesApi.Tests/StartupTests.cs b/src/TesApi.Tests/StartupTests.cs index 9ab1c9a98..1ba6cb562 100644 --- a/src/TesApi.Tests/StartupTests.cs +++ b/src/TesApi.Tests/StartupTests.cs @@ -80,6 +80,8 @@ private void ConfigureTerraOptions() options.WorkspaceStorageAccountName = TerraApiStubData.WorkspaceAccountName; options.WorkspaceStorageContainerName = TerraApiStubData.WorkspaceStorageContainerName; options.WorkspaceStorageContainerResourceId = terraApiStubData.ContainerResourceId.ToString(); + options.SamApiHost = TerraApiStubData.SamApiHost; + options.SamResourceIdForAcrPull = terraApiStubData.AcrPullIdentitySamResourceId.ToString(); }); } @@ -92,10 +94,10 @@ public void ConfigureServices_TerraOptionsAreConfigured_TerraStorageProviderIsRe var serviceProvider = services.BuildServiceProvider(); - var terraStorageProvider = serviceProvider.GetService(); + var storageProvider = serviceProvider.GetService(); - Assert.IsNotNull(terraStorageProvider); - Assert.IsInstanceOfType(terraStorageProvider, typeof(TerraStorageAccessProvider)); + Assert.IsNotNull(storageProvider); + Assert.IsInstanceOfType(storageProvider, typeof(TerraStorageAccessProvider)); } [TestMethod] @@ -105,10 +107,38 @@ public void ConfigureServices_TerraOptionsAreNotConfigured_DefaultStorageProvide var serviceProvider = services.BuildServiceProvider(); - var terraStorageProvider = serviceProvider.GetService(); + var storageProvider = serviceProvider.GetService(); - Assert.IsNotNull(terraStorageProvider); - Assert.IsInstanceOfType(terraStorageProvider, typeof(DefaultStorageAccessProvider)); + Assert.IsNotNull(storageProvider); + Assert.IsInstanceOfType(storageProvider, typeof(DefaultStorageAccessProvider)); + } + + [TestMethod] + public void ConfigureServices_TerraOptionsAreConfigured_TerraActionIdentityProviderIsResolved() + { + ConfigureTerraOptions(); + + startup.ConfigureServices(services); + + var serviceProvider = services.BuildServiceProvider(); + + var terraActionIdentityProvider = serviceProvider.GetService(); + + Assert.IsNotNull(terraActionIdentityProvider); + Assert.IsInstanceOfType(terraActionIdentityProvider, typeof(TerraActionIdentityProvider)); + } + + [TestMethod] + public void ConfigureServices_TerraOptionsAreNotConfigured_DefaultActionIdentityProviderIsResolved() + { + startup.ConfigureServices(services); + + var serviceProvider = services.BuildServiceProvider(); + + var actionIdentityProvider = serviceProvider.GetService(); + + Assert.IsNotNull(actionIdentityProvider); + Assert.IsInstanceOfType(actionIdentityProvider, typeof(DefaultActionIdentityProvider)); } [TestMethod] diff --git a/src/TesApi.Tests/TerraActionIdentityProviderTests.cs b/src/TesApi.Tests/TerraActionIdentityProviderTests.cs new file mode 100644 index 000000000..6e6fadbe2 --- /dev/null +++ b/src/TesApi.Tests/TerraActionIdentityProviderTests.cs @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; +using Tes.ApiClients; +using Tes.ApiClients.Models.Terra; +using TesApi.Web; +using TesApi.Web.Management.Configuration; + +namespace TesApi.Tests +{ + [TestClass] + public class TerraActionIdentityProviderTests + { + private Mock terraSamApiClientMock; + private TerraApiStubData terraApiStubData; + + private Guid notFoundGuid = Guid.NewGuid(); + private Guid errorGuid = Guid.NewGuid(); + + [TestInitialize] + public void SetUp() + { + terraApiStubData = new(); + terraSamApiClientMock = new(); + + terraSamApiClientMock + .Setup(t => t.GetActionManagedIdentityForACRPullAsync(It.Is(g => g.Equals(terraApiStubData.AcrPullIdentitySamResourceId)), It.IsAny(), It.IsAny())) + .ReturnsAsync(terraApiStubData.GetSamActionManagedIdentityApiResponse()); + terraSamApiClientMock + .Setup(t => t.GetActionManagedIdentityForACRPullAsync(It.Is(g => g.Equals(notFoundGuid)), It.IsAny(), It.IsAny())) + .ReturnsAsync((SamActionManagedIdentityApiResponse)null); + terraSamApiClientMock + .Setup(t => t.GetActionManagedIdentityForACRPullAsync(It.Is(g => g.Equals(errorGuid)), It.IsAny(), It.IsAny())) + .Throws(new HttpRequestException("Timeout!!", null, HttpStatusCode.GatewayTimeout)); + } + + private TerraActionIdentityProvider ActionIdentityProviderWithSamResourceId(Guid resourceId) + { + var optionsMock = new Mock>(); + optionsMock.Setup(o => o.Value).Returns(new TerraOptions() { SamApiHost = TerraApiStubData.SamApiHost, SamResourceIdForAcrPull = resourceId.ToString() }); + return new TerraActionIdentityProvider(terraSamApiClientMock.Object, optionsMock.Object, NullLogger.Instance); + } + + [TestMethod] + public async Task GetAcrPullActionIdentity_Success() + { + var actionIdentityProvider = ActionIdentityProviderWithSamResourceId(terraApiStubData.AcrPullIdentitySamResourceId); + var actionIdentity = await actionIdentityProvider.GetAcrPullActionIdentity(cancellationToken: System.Threading.CancellationToken.None); + + Assert.AreEqual(actionIdentity, terraApiStubData.ManagedIdentityObjectId); + } + + [TestMethod] + public async Task GetAcrPullActionIdentity_NotFound() + { + var actionIdentityProvider = ActionIdentityProviderWithSamResourceId(notFoundGuid); + var actionIdentity = await actionIdentityProvider.GetAcrPullActionIdentity(cancellationToken: System.Threading.CancellationToken.None); + + Assert.IsNull(actionIdentity); + } + + [TestMethod] + public async Task GetAcrPullActionIdentity_Error() + { + var actionIdentityProvider = ActionIdentityProviderWithSamResourceId(errorGuid); + var actionIdentity = await actionIdentityProvider.GetAcrPullActionIdentity(cancellationToken: System.Threading.CancellationToken.None); + + Assert.IsNull(actionIdentity); + } + } +} diff --git a/src/TesApi.Tests/TerraApiStubData.cs b/src/TesApi.Tests/TerraApiStubData.cs index c871f356e..7d36e2df5 100644 --- a/src/TesApi.Tests/TerraApiStubData.cs +++ b/src/TesApi.Tests/TerraApiStubData.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Text.Json; +using Tes.ApiClients.Models.Terra; using TesApi.Web.Management.Configuration; using TesApi.Web.Management.Models.Terra; @@ -13,6 +14,7 @@ public class TerraApiStubData { public const string LandingZoneApiHost = "https://landingzone.host"; public const string WsmApiHost = "https://wsm.host"; + public const string SamApiHost = "https://sam.host"; public const string ResourceGroup = "mrg-terra-dev-previ-20191228"; public const string WorkspaceAccountName = "lzaccount1"; public const string SasToken = "SASTOKENSTUB="; @@ -20,17 +22,23 @@ public class TerraApiStubData public const string WorkspaceStorageContainerName = $"sc-{WorkspaceIdValue}"; public const string WsmGetSasResponseStorageUrl = $"https://{WorkspaceAccountName}.blob.core.windows.net/{WorkspaceStorageContainerName}"; + public const string TerraPetName = "pet-2674060218359759651b0"; + public Guid TenantId { get; } = Guid.NewGuid(); public Guid LandingZoneId { get; } = Guid.NewGuid(); public Guid SubscriptionId { get; } = Guid.NewGuid(); public Guid ContainerResourceId { get; } = Guid.NewGuid(); public Guid WorkspaceId { get; } = Guid.Parse(WorkspaceIdValue); + public Guid AcrPullIdentitySamResourceId { get; } = Guid.NewGuid(); public string BatchAccountName => "lzee170c71b6cf678cfca744"; public string Region => "westus3"; public string BatchAccountId => $"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.Batch/batchAccounts/{BatchAccountName}"; + public string ManagedIdentityObjectId => + $"/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{TerraPetName}"; + public string PoolId => "poolId"; public Guid GetWorkspaceIdFromContainerName(string containerName) @@ -50,6 +58,10 @@ public WsmSasTokenApiResponse GetWsmSasTokenApiResponse(string blobName = null) { return JsonSerializer.Deserialize(GetWsmSasTokenApiResponseInJson(blobName)); } + public SamActionManagedIdentityApiResponse GetSamActionManagedIdentityApiResponse() + { + return JsonSerializer.Deserialize(GetSamActionManagedIdentityApiResponseInJson()); + } public TerraOptions GetTerraOptions() { @@ -58,9 +70,11 @@ public TerraOptions GetTerraOptions() WorkspaceId = WorkspaceId.ToString(), LandingZoneApiHost = LandingZoneApiHost, WsmApiHost = WsmApiHost, + SamApiHost = SamApiHost, WorkspaceStorageAccountName = WorkspaceAccountName, WorkspaceStorageContainerName = WorkspaceStorageContainerName, - WorkspaceStorageContainerResourceId = ContainerResourceId.ToString() + WorkspaceStorageContainerResourceId = ContainerResourceId.ToString(), + SamResourceIdForAcrPull = AcrPullIdentitySamResourceId.ToString() }; } @@ -306,6 +320,27 @@ public string GetResourceQuotaApiResponseInJson() }}"; } + public string GetSamActionManagedIdentityApiResponseInJson() + { + return $@"{{ + ""id"": {{ + ""resourceId"": {{ + ""resourceTypeName"": ""private_azure_container_registry"", + ""resourceId"": ""{AcrPullIdentitySamResourceId}"" + }}, + ""action"": ""pull_image"", + ""billingProfileId"": ""{AcrPullIdentitySamResourceId}"" + }}, + ""objectId"": ""{ManagedIdentityObjectId}"", + ""displayName"": ""my nice action identity"", + ""managedResourceGroupCoordinates"": {{ + ""tenantId"": ""{TenantId}"", + ""subscriptionId"": ""{SubscriptionId}"", + ""managedResourceGroupName"": ""{ResourceGroup}"" + }} +}}"; + } + public ApiCreateBatchPoolRequest GetApiCreateBatchPoolRequest() { return new ApiCreateBatchPoolRequest() diff --git a/src/TesApi.Tests/TestServices/TestServiceProvider.cs b/src/TesApi.Tests/TestServices/TestServiceProvider.cs index 9558fb38f..f44014f61 100644 --- a/src/TesApi.Tests/TestServices/TestServiceProvider.cs +++ b/src/TesApi.Tests/TestServices/TestServiceProvider.cs @@ -87,6 +87,7 @@ internal TestServiceProvider( .AddTransient>(_ => NullLogger.Instance) .AddTransient>(_ => NullLogger.Instance) .AddSingleton() + .AddTransient() .AddSingleton() .AddSingleton() .AddTransient() diff --git a/src/TesApi.Web/AzureProxy.cs b/src/TesApi.Web/AzureProxy.cs index c8a00a71c..a97a646c5 100644 --- a/src/TesApi.Web/AzureProxy.cs +++ b/src/TesApi.Web/AzureProxy.cs @@ -49,6 +49,7 @@ public partial class AzureProxy : IAzureProxy private readonly BatchProtocol.BatchServiceClient batchServiceClient; private readonly BatchClient batchClient; private readonly string location; + private readonly Func createNodeManagedIdentityResourceId; private readonly ArmEnvironment armEnvironment; /// @@ -77,6 +78,8 @@ public AzureProxy(IOptions batchAccountOptions, BatchAccoun this.credentialOptions = credentialOptions; this.logger = logger; + createNodeManagedIdentityResourceId = name => $"/subscriptions/{batchAccountInformation.SubscriptionId}/resourceGroups/{batchAccountInformation.ResourceGroupName}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{name}"; + if (string.IsNullOrWhiteSpace(batchAccountOptions.Value.AccountName)) { throw new ArgumentException("The batch account name is missing from the the configuration.", nameof(batchAccountOptions)); @@ -528,5 +531,11 @@ public async Task EnableBatchPoolAutoScaleAsync(string poolId, bool preemptable, await batchClient.PoolOperations.EnableAutoScaleAsync(poolId, formulaFactory(preemptable, preemptable ? currentLowPriority ?? 0 : currentDedicated ?? 0), interval, cancellationToken: cancellationToken); } + + /// + public string GetManagedIdentityInBatchAccountResourceGroup(string identityName) + { + return createNodeManagedIdentityResourceId(identityName); + } } } diff --git a/src/TesApi.Web/BatchScheduler.cs b/src/TesApi.Web/BatchScheduler.cs index 607daf28e..4007b394a 100644 --- a/src/TesApi.Web/BatchScheduler.cs +++ b/src/TesApi.Web/BatchScheduler.cs @@ -17,6 +17,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Newtonsoft.Json; +using Tes.ApiClients; using Tes.Extensions; using Tes.Models; using TesApi.Web.Extensions; @@ -88,6 +89,7 @@ public partial class BatchScheduler : IBatchScheduler private readonly TaskExecutionScriptingManager taskExecutionScriptingManager; private readonly string runnerMD5; private readonly string drsHubApiHost; + private readonly IActionIdentityProvider actionIdentityProvider; private HashSet onlyLogBatchTaskStateOnce = []; @@ -109,6 +111,7 @@ public partial class BatchScheduler : IBatchScheduler /// Batch pool factory /// Service to get allowed vm sizes. /// + /// public BatchScheduler( ILogger logger, IOptions batchGen1Options, @@ -124,7 +127,8 @@ public BatchScheduler( IBatchSkuInformationProvider skuInformationProvider, IBatchPoolFactory poolFactory, IAllowedVmSizesService allowedVmSizesService, - TaskExecutionScriptingManager taskExecutionScriptingManager) + TaskExecutionScriptingManager taskExecutionScriptingManager, + IActionIdentityProvider actionIdentityProvider) { ArgumentNullException.ThrowIfNull(logger); ArgumentNullException.ThrowIfNull(azureProxy); @@ -133,6 +137,7 @@ public BatchScheduler( ArgumentNullException.ThrowIfNull(skuInformationProvider); ArgumentNullException.ThrowIfNull(poolFactory); ArgumentNullException.ThrowIfNull(taskExecutionScriptingManager); + ArgumentNullException.ThrowIfNull(actionIdentityProvider); this.logger = logger; this.azureProxy = azureProxy; @@ -152,6 +157,7 @@ public BatchScheduler( this.globalManagedIdentity = batchNodesOptions.Value.GlobalManagedIdentity; this.allowedVmSizesService = allowedVmSizesService; this.taskExecutionScriptingManager = taskExecutionScriptingManager; + this.actionIdentityProvider = actionIdentityProvider; _batchPoolFactory = poolFactory; batchPrefix = batchSchedulingOptions.Value.Prefix; logger.LogInformation("BatchPrefix: {BatchPrefix}", batchPrefix); @@ -512,7 +518,20 @@ private async Task AddBatchTaskAsync(TesTask tesTask, CancellationToken cancella if (tesTask.Resources?.ContainsBackendParameterValue(TesResources.SupportedBackendParameters.workflow_execution_identity) == true) { - identities.Add(tesTask.Resources?.GetBackendParameterValue(TesResources.SupportedBackendParameters.workflow_execution_identity)); + var workflowId = tesTask.Resources?.GetBackendParameterValue(TesResources.SupportedBackendParameters.workflow_execution_identity); + + if (!NodeTaskBuilder.IsValidManagedIdentityResourceId(workflowId)) + { + workflowId = azureProxy.GetManagedIdentityInBatchAccountResourceGroup(workflowId); + } + + identities.Add(workflowId); + } + + var acrPullIdentity = await actionIdentityProvider.GetAcrPullActionIdentity(CancellationToken.None); + if (acrPullIdentity is not null) + { + identities.Add(acrPullIdentity); } var virtualMachineInfo = await GetVmSizeAsync(tesTask, cancellationToken); @@ -531,7 +550,7 @@ private async Task AddBatchTaskAsync(TesTask tesTask, CancellationToken cancella modelPoolFactory: (id, ct) => GetPoolSpecification( name: id, displayName: displayName, - poolIdentity: GetBatchPoolIdentity(identities.ToArray()), + poolIdentity: GetBatchPoolIdentity([.. identities]), vmSize: virtualMachineInfo.VmSize, vmFamily: virtualMachineInfo.VmFamily, preemptable: virtualMachineInfo.LowPriority, @@ -543,7 +562,7 @@ private async Task AddBatchTaskAsync(TesTask tesTask, CancellationToken cancella var jobOrTaskId = $"{tesTask.Id}-{tesTask.Logs.Count}"; tesTask.PoolId = poolId; - var cloudTask = await ConvertTesTaskToBatchTaskUsingRunnerAsync(jobOrTaskId, tesTask, cancellationToken); + var cloudTask = await ConvertTesTaskToBatchTaskUsingRunnerAsync(jobOrTaskId, tesTask, acrPullIdentity, cancellationToken); logger.LogInformation($"Creating batch task for TES task {tesTask.Id}. Using VM size {virtualMachineInfo.VmSize}."); await azureProxy.AddBatchTaskAsync(tesTask.Id, cloudTask, poolId, cancellationToken); @@ -903,34 +922,33 @@ private ValueTask HandleTesTaskTransitionAsync(TesTask tesTask, CombinedBa .FirstOrDefault(m => (m.Condition is null || m.Condition(tesTask)) && (m.CurrentBatchTaskState is null || m.CurrentBatchTaskState == combinedBatchTaskInfo.BatchTaskState)) ?.ActionAsync(tesTask, combinedBatchTaskInfo, cancellationToken) ?? ValueTask.FromResult(false); - private async Task ConvertTesTaskToBatchTaskUsingRunnerAsync(string taskId, TesTask task, + private async Task ConvertTesTaskToBatchTaskUsingRunnerAsync(string taskId, TesTask task, string acrPullIdentity, CancellationToken cancellationToken) { - var nodeTaskCreationOptions = await GetNodeTaskConversionOptionsAsync(task, cancellationToken); + var nodeTaskCreationOptions = await GetNodeTaskConversionOptionsAsync(task, acrPullIdentity, cancellationToken); var assets = await taskExecutionScriptingManager.PrepareBatchScriptAsync(task, nodeTaskCreationOptions, cancellationToken); var batchRunCommand = taskExecutionScriptingManager.ParseBatchRunCommand(assets); - var cloudTask = new CloudTask(taskId, batchRunCommand) + return new(taskId, batchRunCommand) { Constraints = new(maxWallClockTime: taskMaxWallClockTime, retentionTime: TimeSpan.Zero, maxTaskRetryCount: 0), UserIdentity = new(new AutoUserSpecification(elevationLevel: ElevationLevel.Admin, scope: AutoUserScope.Pool)), EnvironmentSettings = assets.Environment?.Select(pair => new EnvironmentSetting(pair.Key, pair.Value)).ToList(), }; - - return cloudTask; } - private async Task GetNodeTaskConversionOptionsAsync(TesTask task, CancellationToken cancellationToken) + private async Task GetNodeTaskConversionOptionsAsync(TesTask task, string acrPullIdentity, CancellationToken cancellationToken) { - var nodeTaskCreationOptions = new NodeTaskConversionOptions( + return new( DefaultStorageAccountName: defaultStorageAccountName, AdditionalInputs: await GetAdditionalCromwellInputsAsync(task, cancellationToken), GlobalManagedIdentity: globalManagedIdentity, + AcrPullIdentity: acrPullIdentity, DrsHubApiHost: drsHubApiHost, - SetContentMd5OnUpload: batchNodesSetContentMd5OnUpload); - return nodeTaskCreationOptions; + SetContentMd5OnUpload: batchNodesSetContentMd5OnUpload + ); } private async Task> GetAdditionalCromwellInputsAsync(TesTask task, CancellationToken cancellationToken) @@ -1410,7 +1428,7 @@ private static Dictionary DelimitedTextToDictionary(string text, .ToDictionary(kv => kv.Key, kv => kv.Value); /// - /// Class that captures how transitions from current state to the new state, given the current Batch task state and optional condition. + /// Class that captures how transitions from current state to the new state, given the current Batch task state and optional condition. /// Transitions typically include an action that needs to run in order for the task to move to the new state. /// private class TesTaskStateTransition diff --git a/src/TesApi.Web/CachingWithRetriesAzureProxy.cs b/src/TesApi.Web/CachingWithRetriesAzureProxy.cs index 7b3e04006..69fb96fc0 100644 --- a/src/TesApi.Web/CachingWithRetriesAzureProxy.cs +++ b/src/TesApi.Web/CachingWithRetriesAzureProxy.cs @@ -162,6 +162,9 @@ async Task IAzureProxy.GetStorageAccountInfoAsync(string sto /// string IAzureProxy.GetArmRegion() => azureProxy.GetArmRegion(); + /// + string IAzureProxy.GetManagedIdentityInBatchAccountResourceGroup(string identityName) => azureProxy.GetManagedIdentityInBatchAccountResourceGroup(identityName); + /// Task IAzureProxy.GetFullAllocationStateAsync(string poolId, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAndCachingAsync( diff --git a/src/TesApi.Web/DefaultActionIdentityProvider.cs b/src/TesApi.Web/DefaultActionIdentityProvider.cs new file mode 100644 index 000000000..94354ff03 --- /dev/null +++ b/src/TesApi.Web/DefaultActionIdentityProvider.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Threading; +using System.Threading.Tasks; + +namespace TesApi.Web +{ + /// + /// A default no-op implementation for use outside Terra + /// + public class DefaultActionIdentityProvider : IActionIdentityProvider + { + /// + /// Returns null, to provide the default behavior + /// + public Task GetAcrPullActionIdentity(CancellationToken cancellationToken) + { + return Task.FromResult(null); + } + } +} diff --git a/src/TesApi.Web/IActionIdentityProvider.cs b/src/TesApi.Web/IActionIdentityProvider.cs new file mode 100644 index 000000000..e3d25546f --- /dev/null +++ b/src/TesApi.Web/IActionIdentityProvider.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Threading; +using System.Threading.Tasks; + +namespace TesApi.Web +{ + /// + /// Provides methods for interacting with "action identities," managed identities that are associated with specific actions (ex. pulling Docker images) + /// rather than individual users. At the time of writing this is a Terra-only concern, and the Terra control plane service Sam owns these identities. + /// + public interface IActionIdentityProvider + { + /// + /// Retrieves the action identity to use for pulling ACR images, if one exists + /// + /// A for controlling the lifetime of the asynchronous operation. + /// The resource id of the action identity, if one exists + public Task GetAcrPullActionIdentity(CancellationToken cancellationToken); + + } +} diff --git a/src/TesApi.Web/IAzureProxy.cs b/src/TesApi.Web/IAzureProxy.cs index 34246e5d0..75015580a 100644 --- a/src/TesApi.Web/IAzureProxy.cs +++ b/src/TesApi.Web/IAzureProxy.cs @@ -222,6 +222,13 @@ public interface IAzureProxy /// arm region string GetArmRegion(); + /// + /// Gets the managed identity in batch account resource group. + /// + /// Name of the identity. + /// Resource Id of the purported managed identity. + string GetManagedIdentityInBatchAccountResourceGroup(string identityName); + /// /// Disables AutoScale in a Batch Pool /// diff --git a/src/TesApi.Web/Management/Configuration/TerraOptions.cs b/src/TesApi.Web/Management/Configuration/TerraOptions.cs index a46b42efb..cc1086461 100644 --- a/src/TesApi.Web/Management/Configuration/TerraOptions.cs +++ b/src/TesApi.Web/Management/Configuration/TerraOptions.cs @@ -13,6 +13,7 @@ public class TerraOptions /// public const string SectionName = "Terra"; private const int DefaultSasTokenExpirationInSeconds = 60 * 24 * 3; // 3 days + private const int DefaultSamActionIdentityCacheTTLMinutes = 5; /// /// Landing zone id containing the Tes back-end resources @@ -20,7 +21,7 @@ public class TerraOptions public string LandingZoneId { get; set; } /// - /// Landing zone api host. + /// Landing zone api host. /// public string LandingZoneApiHost { get; set; } @@ -29,6 +30,21 @@ public class TerraOptions /// public string WsmApiHost { get; set; } + /// + /// Sam api host. + /// + public string SamApiHost { get; set; } + + /// + /// Id of the Sam resource associated with the ACR pull identity + /// + public string SamResourceIdForAcrPull { get; set; } + + /// + /// Amount of time that cached action identities should live before we ask Sam for them again + /// + public int SamActionIdentityCacheTTLMinutes { get; set; } = DefaultSamActionIdentityCacheTTLMinutes; + /// /// Workspace storage container resource id /// diff --git a/src/TesApi.Web/Runner/NodeTaskBuilder.cs b/src/TesApi.Web/Runner/NodeTaskBuilder.cs index 883043e45..785f60065 100644 --- a/src/TesApi.Web/Runner/NodeTaskBuilder.cs +++ b/src/TesApi.Web/Runner/NodeTaskBuilder.cs @@ -175,7 +175,7 @@ public NodeTaskBuilder WithContainerCommands(List commands) } /// - /// + /// /// /// /// @@ -327,6 +327,28 @@ public NodeTaskBuilder WithResourceIdManagedIdentity(string resourceId) return this; } + /// + /// Sets the managed identity to be used for ACR pulls for the node task. If the resource ID is empty or null, the property won't be set. + /// + /// A valid managed identity resource ID + /// + public NodeTaskBuilder WithAcrPullResourceIdManagedIdentity(string resourceId) + { + if (string.IsNullOrEmpty(resourceId)) + { + return this; + } + + if (!IsValidManagedIdentityResourceId(resourceId)) + { + throw new ArgumentException("Invalid resource ID. The ID must be a valid Azure resource ID.", nameof(resourceId)); + } + + nodeTask.RuntimeOptions ??= new RuntimeOptions(); + nodeTask.RuntimeOptions.AcrPullManagedIdentityResourceId = resourceId; + return this; + } + /// /// (Optional) sets the azure authority host for the node task. If not set, the default Azure Public cloud is used. /// @@ -345,7 +367,7 @@ public NodeTaskBuilder WithAzureCloudIdentityConfig(AzureEnvironmentConfig azure } /// - /// Returns true of the value provided is a valid resource id for a managed identity. + /// Returns true of the value provided is a valid resource id for a managed identity. /// /// /// @@ -355,7 +377,7 @@ public static bool IsValidManagedIdentityResourceId(string resourceId) { return false; } - //Ignore the case because constant segments could be lower case, pascal case or camel case. + //Ignore the case because constant segments could be lower case, pascal case or camel case. // e.g. /resourcegroup/ or /resourceGroup/ return Regex.IsMatch(resourceId, ManagedIdentityResourceIdPattern, RegexOptions.IgnoreCase); } diff --git a/src/TesApi.Web/Runner/TaskExecutionScriptingManager.cs b/src/TesApi.Web/Runner/TaskExecutionScriptingManager.cs index 1e54448fc..b04170a7d 100644 --- a/src/TesApi.Web/Runner/TaskExecutionScriptingManager.cs +++ b/src/TesApi.Web/Runner/TaskExecutionScriptingManager.cs @@ -56,7 +56,7 @@ public TaskExecutionScriptingManager(IStorageAccessProvider storageAccessProvide } /// - /// Prepares the runtime scripting assets required for the execution of a TES task in a Batch node using the TES runner. + /// Prepares the runtime scripting assets required for the execution of a TES task in a Batch node using the TES runner. /// /// /// @@ -72,7 +72,7 @@ public async Task PrepareBatchScriptAsync(TesTask tesTask List> environment = [new(nameof(NodeTaskResolverOptions), JsonConvert.SerializeObject( - taskToNodeConverter.ToNodeTaskResolverOptions(nodeTaskConversionOptions), + taskToNodeConverter.ToNodeTaskResolverOptions(tesTask, nodeTaskConversionOptions), DefaultSerializerSettings))]; return new BatchScriptAssetsInfo(nodeTaskUrl, environment.ToDictionary().AsReadOnly()); diff --git a/src/TesApi.Web/Runner/TaskToNodeTaskConverter.cs b/src/TesApi.Web/Runner/TaskToNodeTaskConverter.cs index 6fab774d4..35dc92178 100644 --- a/src/TesApi.Web/Runner/TaskToNodeTaskConverter.cs +++ b/src/TesApi.Web/Runner/TaskToNodeTaskConverter.cs @@ -43,31 +43,31 @@ public class TaskToNodeTaskConverter private readonly TerraOptions terraOptions; private readonly ILogger logger; private readonly IList externalStorageContainers; - private readonly BatchAccountOptions batchAccountOptions; + private readonly IAzureProxy azureProxy; private readonly AzureEnvironmentConfig azureCloudIdentityConfig; /// /// Constructor of TaskToNodeTaskConverter /// /// - /// /// - /// + /// + /// /// /// - public TaskToNodeTaskConverter(IOptions terraOptions, IStorageAccessProvider storageAccessProvider, IOptions storageOptions, IOptions batchAccountOptions, AzureEnvironmentConfig azureCloudIdentityConfig, ILogger logger) + public TaskToNodeTaskConverter(IOptions terraOptions, IOptions storageOptions, IStorageAccessProvider storageAccessProvider, IAzureProxy azureProxy, AzureEnvironmentConfig azureCloudIdentityConfig, ILogger logger) { ArgumentNullException.ThrowIfNull(terraOptions); ArgumentNullException.ThrowIfNull(storageOptions); ArgumentNullException.ThrowIfNull(storageAccessProvider); - ArgumentNullException.ThrowIfNull(batchAccountOptions); + ArgumentNullException.ThrowIfNull(azureProxy); ArgumentNullException.ThrowIfNull(azureCloudIdentityConfig); ArgumentNullException.ThrowIfNull(logger); this.terraOptions = terraOptions.Value; this.logger = logger; this.storageAccessProvider = storageAccessProvider; - this.batchAccountOptions = batchAccountOptions.Value; + this.azureProxy = azureProxy; this.azureCloudIdentityConfig = azureCloudIdentityConfig; externalStorageContainers = StorageUrlUtils.GetExternalStorageContainerInfos(storageOptions.Value); } @@ -80,16 +80,17 @@ protected TaskToNodeTaskConverter() { } /// /// Generates . /// + /// The TES task. /// The node task conversion options. /// Environment required for runner to retrieve blobs from storage. - public virtual NodeTaskResolverOptions ToNodeTaskResolverOptions(NodeTaskConversionOptions nodeTaskConversionOptions) + public virtual NodeTaskResolverOptions ToNodeTaskResolverOptions(TesTask task, NodeTaskConversionOptions nodeTaskConversionOptions) { try { var builder = new NodeTaskBuilder(); builder.WithAzureCloudIdentityConfig(azureCloudIdentityConfig) .WithStorageEventSink(storageAccessProvider.GetInternalTesBlobUrlWithoutSasToken(blobPath: string.Empty)) - .WithResourceIdManagedIdentity(nodeTaskConversionOptions.GlobalManagedIdentity); + .WithResourceIdManagedIdentity(GetNodeManagedIdentityResourceId(nodeTaskConversionOptions.GlobalManagedIdentity, task)); if (terraOptions is not null && !string.IsNullOrEmpty(terraOptions.WsmApiHost)) { @@ -134,6 +135,7 @@ public virtual async Task ToNodeTaskAsync(TesTask task, NodeTaskConver builder.WithId(task.Id) .WithAzureCloudIdentityConfig(azureCloudIdentityConfig) .WithResourceIdManagedIdentity(GetNodeManagedIdentityResourceId(task, nodeTaskConversionOptions.GlobalManagedIdentity)) + .WithAcrPullResourceIdManagedIdentity(nodeTaskConversionOptions.AcrPullIdentity) .WithWorkflowId(task.WorkflowId) .WithContainerCommands(executor.Command) .WithContainerImage(executor.Image) @@ -321,7 +323,7 @@ private async Task> PrepareInputsForMappingAsync(TesTask tesTask, /// /// This returns the node managed identity resource id from the task if it is set, otherwise it returns the global managed identity. - /// If the value in the workflow identity is not a full resource id, it is assumed to be the name. In this case, the resource id is constructed from the name. + /// If the value in the workflow identity is not a full resource id, it is assumed to be the name. In this case, the resource id is constructed from the name. /// /// /// @@ -332,8 +334,13 @@ public string GetNodeManagedIdentityResourceId(TesTask task, string globalManage task.Resources?.GetBackendParameterValue(TesResources.SupportedBackendParameters .workflow_execution_identity); - if (string.IsNullOrEmpty(workflowId)) + if (string.IsNullOrWhiteSpace(workflowId)) { + if (!NodeTaskBuilder.IsValidManagedIdentityResourceId(globalManagedIdentity)) + { + throw new TesException("NoManagedIdentityForRunner", "Neither the TES server nor the task provided an Azure User Managed Identity for the task runner. Please check your configuration."); + } + return globalManagedIdentity; } @@ -342,7 +349,38 @@ public string GetNodeManagedIdentityResourceId(TesTask task, string globalManage return workflowId; } - return $"/subscriptions/{batchAccountOptions.SubscriptionId}/resourceGroups/{batchAccountOptions.ResourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{workflowId}"; + return azureProxy.GetManagedIdentityInBatchAccountResourceGroup(workflowId); + } + + /// + /// This returns the global managed identity if it is set, otherwise it returns the node managed identity resource id from the task. + /// If the value in the workflow identity is not a full resource id, it is assumed to be the name. In this case, the resource id is constructed from the name. + /// + /// + /// + /// + public string GetNodeManagedIdentityResourceId(string globalManagedIdentity, TesTask task) + { + if (NodeTaskBuilder.IsValidManagedIdentityResourceId(globalManagedIdentity)) + { + return globalManagedIdentity; + } + + var workflowId = + task.Resources?.GetBackendParameterValue(TesResources.SupportedBackendParameters + .workflow_execution_identity); + + if (string.IsNullOrWhiteSpace(workflowId)) + { + throw new TesException("NoManagedIdentityForRunner", "Neither the TES server nor the task provided an Azure User Managed Identity for the task runner. Please check your configuration."); + } + + if (NodeTaskBuilder.IsValidManagedIdentityResourceId(workflowId)) + { + return workflowId; + } + + return azureProxy.GetManagedIdentityInBatchAccountResourceGroup(workflowId); } @@ -400,7 +438,7 @@ private string AppendSasTokenIfExternalAccount(string url) private TesInput PrepareLocalFileInput(TesInput input, string defaultStorageAccountName) { //When Cromwell runs in local mode with a Blob FUSE drive, the URL property may contain an absolute path. - //The path must be converted to a URL. For Terra this scenario doesn't apply. + //The path must be converted to a URL. For Terra this scenario doesn't apply. if (StorageUrlUtils.IsLocalAbsolutePath(input.Url)) { var convertedUrl = StorageUrlUtils.ConvertLocalPathOrCromwellLocalPathToUrl(input.Url, defaultStorageAccountName); @@ -552,8 +590,9 @@ private static string AppendParentDirectoryIfSet(string inputPath, string pathPa /// /// /// + /// /// /// public record NodeTaskConversionOptions(IList AdditionalInputs = default, string DefaultStorageAccountName = default, - string GlobalManagedIdentity = default, string DrsHubApiHost = default, bool SetContentMd5OnUpload = false); + string GlobalManagedIdentity = default, string AcrPullIdentity = default, string DrsHubApiHost = default, bool SetContentMd5OnUpload = false); } diff --git a/src/TesApi.Web/Startup.cs b/src/TesApi.Web/Startup.cs index 901a7e449..be52af68f 100644 --- a/src/TesApi.Web/Startup.cs +++ b/src/TesApi.Web/Startup.cs @@ -114,6 +114,7 @@ public void ConfigureServices(IServiceCollection services) .AddAutoMapper(typeof(MappingProfilePoolToWsmRequest)) .AddSingleton() .AddSingleton(s => s.GetRequiredService()) // Return the already declared retry policy builder + .AddSingleton(CreateActionIdentityProvider) .AddSingleton() .AddSingleton() .AddSingleton() @@ -388,6 +389,25 @@ TerraWsmApiClient CreateTerraApiClient(IServiceProvider services) throw new InvalidOperationException("Terra WSM API Host is not configured."); } + IActionIdentityProvider CreateActionIdentityProvider(IServiceProvider services) + { + logger.LogInformation("Attempting to create an ActionIdentityProvider"); + + if (TerraOptionsAreConfigured(services)) + { + var options = services.GetRequiredService>(); + + ValidateRequiredOptionsForTerraActionIdentities(options.Value); + + var samClient = ActivatorUtilities.CreateInstance(services, options.Value.SamApiHost); + + logger.LogInformation("Creating TerraActionIdentityProvider"); + return ActivatorUtilities.CreateInstance(services, samClient); + } + + return ActivatorUtilities.CreateInstance(services); + } + static void ValidateRequiredOptionsForTerraStorageProvider(TerraOptions terraOptions) { ArgumentException.ThrowIfNullOrEmpty(terraOptions.WorkspaceId, nameof(terraOptions.WorkspaceId)); @@ -397,6 +417,12 @@ static void ValidateRequiredOptionsForTerraStorageProvider(TerraOptions terraOpt ArgumentException.ThrowIfNullOrEmpty(terraOptions.WsmApiHost, nameof(terraOptions.WsmApiHost)); } + static void ValidateRequiredOptionsForTerraActionIdentities(TerraOptions terraOptions) + { + ArgumentException.ThrowIfNullOrEmpty(terraOptions.SamApiHost, nameof(terraOptions.SamApiHost)); + ArgumentException.ThrowIfNullOrEmpty(terraOptions.SamApiHost, nameof(terraOptions.SamResourceIdForAcrPull)); + } + BatchAccountResourceInformation CreateBatchAccountResourceInformation(IServiceProvider services) { var options = services.GetRequiredService>(); diff --git a/src/TesApi.Web/TerraActionIdentityProvider.cs b/src/TesApi.Web/TerraActionIdentityProvider.cs new file mode 100644 index 000000000..1c8e248a7 --- /dev/null +++ b/src/TesApi.Web/TerraActionIdentityProvider.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Tes.ApiClients; +using TesApi.Web.Management.Configuration; + +namespace TesApi.Web +{ + /// + /// An ActionIdentityProvider implementation for use in Terra. Obtains action identities from Sam. + /// + public class TerraActionIdentityProvider : IActionIdentityProvider + { + private readonly Guid samResourceIdForAcrPull; + private readonly TerraSamApiClient terraSamApiClient; + private readonly TimeSpan samCacheTTL; + private readonly ILogger Logger; + + /// + /// An ActionIdentityProvider implementation for use in Terra. Obtains action identities from Sam. + /// + /// + /// + /// + public TerraActionIdentityProvider(TerraSamApiClient terraSamApiClient, IOptions terraOptions, ILogger Logger) + { + ArgumentNullException.ThrowIfNull(terraOptions); + ArgumentNullException.ThrowIfNull(terraOptions.Value.SamResourceIdForAcrPull); + this.samResourceIdForAcrPull = Guid.Parse(terraOptions.Value.SamResourceIdForAcrPull); + this.samCacheTTL = TimeSpan.FromMinutes(terraOptions.Value.SamActionIdentityCacheTTLMinutes); + this.terraSamApiClient = terraSamApiClient; + this.Logger = Logger; + } + + + /// + /// Retrieves the action identity to use for pulling ACR images, if one exists + /// + /// A for controlling the lifetime of the asynchronous operation. + /// The resource id of the action identity, if one exists. Otherwise, null. + public async Task GetAcrPullActionIdentity(CancellationToken cancellationToken) + { + try + { + var response = await terraSamApiClient.GetActionManagedIdentityForACRPullAsync(samResourceIdForAcrPull, samCacheTTL, CancellationToken.None); + if (response is null) + { + // Corresponds to no identity existing in Sam, or the user not having access to it. + Logger.LogInformation(@"Found no ACR Pull action identity in Sam for {id}", samResourceIdForAcrPull); + return null; + } + else + { + Logger.LogInformation(@"Successfully fetched ACR action identity from Sam: {ObjectId}", response.ObjectId); + return response.ObjectId; + } + } + catch (Exception e) + { + Logger.LogError(e, "Failed when trying to obtain an ACR Pull action identity from Sam"); + return null; + } + } + + } +}