diff --git a/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs b/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs new file mode 100644 index 000000000..4db1f88c4 --- /dev/null +++ b/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Frozen; +using CommonUtilities.AzureCloud; + +namespace CommonUtilities.Tests +{ + [TestClass] + public class ArmEnvironmentEndpointsTests + { + public enum Cloud + { + Public, + USGovernment, + China + } + + private struct DictionaryValueEqualityComparer : IEqualityComparer + { + readonly bool IEqualityComparer.Equals(object? x, object? y) + { + return x switch + { + IReadOnlyDictionary X => y switch + { + T Y => DictionaryValueEqualityComparer.Equals(X, Y), + _ => false, + }, + null => y is null, + _ => x.Equals(y), + }; + } + + private static bool Equals(IReadOnlyDictionary x, T y) + { + var actual = FrozenDictionary(x.Keys.Select(name => (Name: name, Property: typeof(T).GetProperty(name))).Where(e => e.Property is not null).Select(e => (e.Name, e.Property!.GetValue(y)))); + if (x.Count != actual.Count) return false; + if (x.Keys.Order(StringComparer.Ordinal).Zip(actual.Keys.Order(StringComparer.Ordinal)).Any(e => !e.First.Equals(e.Second, StringComparison.Ordinal))) return false; + + foreach (var key in x.Keys) + { + switch (x[key]) + { + case IReadOnlyDictionary expected: + if (!ArmEnvironmentEndpointsTests.Equals(expected, actual[key])) return false; + break; + + default: + if (!Equals(x[key], actual[key])) return false; + break; + } + } + + return true; + } + + readonly int IEqualityComparer.GetHashCode(object obj) => obj.GetHashCode(); + } + + private static bool Equals(IReadOnlyDictionary x, object? y) => y switch + { + AzureCloudConfig.AuthenticationDetails z => ((IEqualityComparer)new DictionaryValueEqualityComparer()).Equals(x, z), + AzureCloudConfig.EndpointSuffixes z => ((IEqualityComparer)new DictionaryValueEqualityComparer()).Equals(x, z), + _ => false, + }; + + private static FrozenDictionary FrozenDictionary(IEnumerable<(string, object?)> values) + { + var dictionary = new Dictionary(StringComparer.Ordinal); + dictionary.AddRange(values.ToDictionary()); + return dictionary.ToFrozenDictionary(); + } + + private static readonly FrozenDictionary> CloudEndpoints + = new Dictionary>() + { + { Cloud.Public, FrozenDictionary( + [ + (nameof(AzureCloudConfig.Authentication), FrozenDictionary( + [ + (nameof(AzureCloudConfig.AuthenticationDetails.LoginEndpointUrl), "https://login.microsoftonline.com"), + (nameof(AzureCloudConfig.AuthenticationDetails.Tenant), "common"), + ])), + (nameof(AzureCloudConfig.ResourceManagerUrl), "https://management.azure.com/"), + (nameof(AzureCloudConfig.ApplicationInsightsResourceUrl), "https://api.applicationinsights.io"), + (nameof(AzureCloudConfig.ApplicationInsightsTelemetryChannelResourceUrl), "https://dc.applicationinsights.azure.com/v2/track"), + (nameof(AzureCloudConfig.BatchUrl), "https://batch.core.windows.net/"), + (nameof(AzureCloudConfig.Suffixes), FrozenDictionary( + [ + (nameof(AzureCloudConfig.Suffixes.AcrLoginServerSuffix), "azurecr.io"), + (nameof(AzureCloudConfig.Suffixes.KeyVaultDnsSuffix), "vault.azure.net"), + (nameof(AzureCloudConfig.Suffixes.StorageSuffix), "core.windows.net"), + (nameof(AzureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix), "postgres.database.azure.com"), + ])), + ] + )}, + + { Cloud.USGovernment, FrozenDictionary( + [ + (nameof(AzureCloudConfig.Authentication), FrozenDictionary( + [ + (nameof(AzureCloudConfig.AuthenticationDetails.LoginEndpointUrl), "https://login.microsoftonline.us"), + (nameof(AzureCloudConfig.AuthenticationDetails.Tenant), "common"), + ])), + (nameof(AzureCloudConfig.ResourceManagerUrl), "https://management.usgovcloudapi.net"), + (nameof(AzureCloudConfig.ApplicationInsightsResourceUrl), "https://api.applicationinsights.us"), + (nameof(AzureCloudConfig.ApplicationInsightsTelemetryChannelResourceUrl), "https://dc.applicationinsights.us/v2/track"), + (nameof(AzureCloudConfig.BatchUrl), "https://batch.core.usgovcloudapi.net"), + (nameof(AzureCloudConfig.Suffixes), FrozenDictionary( + [ + (nameof(AzureCloudConfig.Suffixes.AcrLoginServerSuffix), "azurecr.us"), + (nameof(AzureCloudConfig.Suffixes.KeyVaultDnsSuffix), "vault.usgovcloudapi.net"), + (nameof(AzureCloudConfig.Suffixes.StorageSuffix), "core.usgovcloudapi.net"), + (nameof(AzureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix), "postgres.database.usgovcloudapi.net"), + ])), + ] + )}, + + { Cloud.China, FrozenDictionary( + [ + (nameof(AzureCloudConfig.Authentication), FrozenDictionary( + [ + (nameof(AzureCloudConfig.AuthenticationDetails.LoginEndpointUrl), "https://login.chinacloudapi.cn"), + (nameof(AzureCloudConfig.AuthenticationDetails.Tenant), "common"), + ])), + (nameof(AzureCloudConfig.ResourceManagerUrl), "https://management.chinacloudapi.cn"), + (nameof(AzureCloudConfig.ApplicationInsightsResourceUrl), "https://api.applicationinsights.azure.cn"), + (nameof(AzureCloudConfig.ApplicationInsightsTelemetryChannelResourceUrl), "https://dc.applicationinsights.azure.cn/v2/track"), + (nameof(AzureCloudConfig.BatchUrl), "https://batch.chinacloudapi.cn"), + (nameof(AzureCloudConfig.Suffixes), FrozenDictionary( + [ + (nameof(AzureCloudConfig.Suffixes.AcrLoginServerSuffix), "azurecr.cn"), + (nameof(AzureCloudConfig.Suffixes.KeyVaultDnsSuffix), "vault.azure.cn"), + (nameof(AzureCloudConfig.Suffixes.StorageSuffix), "core.chinacloudapi.cn"), + (nameof(AzureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix), "postgres.database.chinacloudapi.cn"), + ])), + ] + )}, + }.ToFrozenDictionary(); + + [DataTestMethod] + [DataRow("AzureCloud", "https://management.azure.com//.default", DisplayName = "AzureCloud")] + [DataRow("AzurePublicCloud", "https://management.azure.com//.default", DisplayName = "AzurePublicCloud")] + [DataRow("AzureUSGovernment", "https://management.usgovcloudapi.net/.default", DisplayName = "AzureUSGovernment")] + [DataRow("AzureUSGovernmentCloud", "https://management.usgovcloudapi.net/.default", DisplayName = "AzureUSGovernmentCloud")] + [DataRow("AzureChinaCloud", "https://management.chinacloudapi.cn/.default", DisplayName = "AzureChinaCloud")] + public async Task FromKnownCloudNameAsync_ExpectedDefaultTokenScope(string cloud, string audience) + { + var environment = await AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloud, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions())); + Assert.AreEqual(audience, GetPropertyFromEnvironment(environment, nameof(AzureCloudConfig.DefaultTokenScope))); + } + + private static T? GetPropertyFromEnvironment(AzureCloudConfig environment, string property) + { + return (T?)environment.GetType().GetProperty(property)?.GetValue(environment); + } + + [DataTestMethod] + [DataRow(Cloud.Public, "AzureCloud", DisplayName = "All generally available global Azure regions")] + [DataRow(Cloud.USGovernment, "AzureUSGovernment", DisplayName = "Azure Government")] + [DataRow(Cloud.China, "AzureChinaCloud", DisplayName = "Microsoft Azure operated by 21Vianet")] + public async Task FromKnownCloudNameAsync_ExpectedValues(Cloud cloud, string cloudName) + { + var environment = await AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloudName, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions())); + foreach (var (property, value) in CloudEndpoints[cloud]) + { + switch (value) + { + case var x when x is string: + Assert.AreEqual(value, GetPropertyFromEnvironment(environment, property)); + break; + + case var x when x is IReadOnlyDictionary: + switch (property) + { + case nameof(AzureCloudConfig.Authentication): + Assert.AreEqual(value, GetPropertyFromEnvironment(environment, property), new DictionaryValueEqualityComparer()); + break; + + case nameof(AzureCloudConfig.Suffixes): + Assert.AreEqual(value, GetPropertyFromEnvironment(environment, property), new DictionaryValueEqualityComparer()); + break; + } + break; + + default: + throw new NotSupportedException(); + } + } + } + } +} diff --git a/src/CommonUtilities.Tests/CommonUtilities.Tests.csproj b/src/CommonUtilities.Tests/CommonUtilities.Tests.csproj index f0dff427d..276cddb39 100644 --- a/src/CommonUtilities.Tests/CommonUtilities.Tests.csproj +++ b/src/CommonUtilities.Tests/CommonUtilities.Tests.csproj @@ -10,10 +10,18 @@ - + + + + + + + + + - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/CommonUtilities/AzureCloudConfig.cs b/src/CommonUtilities/AzureCloudConfig.cs index 18ce21ad8..a801e4a7d 100644 --- a/src/CommonUtilities/AzureCloudConfig.cs +++ b/src/CommonUtilities/AzureCloudConfig.cs @@ -3,15 +3,15 @@ using System.Text.Json; using System.Text.Json.Serialization; +using Azure.Identity; using Azure.ResourceManager; -using Microsoft.Azure.Management.ResourceManager.Fluent; -using Polly; +using Microsoft.Extensions.Options; namespace CommonUtilities.AzureCloud { public class AzureCloudConfig { - private const string defaultAzureCloudMetadataUrlApiVersion = "2023-11-01"; + private const string DefaultAzureCloudMetadataUrlApiVersion = "2023-11-01"; public const string DefaultAzureCloudName = "AzureCloud"; [JsonPropertyName("portal")] @@ -65,131 +65,247 @@ public class AzureCloudConfig [JsonPropertyName("ossrDbmsResourceId")] public string? OssrDbmsResourceUrl { get; set; } + public class AuthenticationDetails + { + [JsonPropertyName("loginEndpoint")] + public string LoginEndpointUrl { get; set; } = "https://login.microsoftonline.com"; + + [JsonPropertyName("audiences")] + public List? Audiences { get; set; } + + [JsonPropertyName("tenant")] + public string? Tenant { get; set; } + + [JsonPropertyName("identityProvider")] + public string? IdentityProvider { get; set; } + } + + public class EndpointSuffixes + { + [JsonPropertyName("acrLoginServer")] + public string? AcrLoginServerSuffix { get; set; } + + [JsonPropertyName("sqlServerHostname")] + public string? SqlServerHostnameSuffix { get; set; } + + [JsonPropertyName("keyVaultDns")] + public string? KeyVaultDnsSuffix { get; set; } + + [JsonPropertyName("storage")] + public string? StorageSuffix { get; set; } + + [JsonPropertyName("storageSyncEndpointSuffix")] + public string? StorageSyncEndpointSuffix { get; set; } + + [JsonPropertyName("mhsmDns")] + public string? ManagedHsmDnsSuffix { get; set; } + + [JsonPropertyName("mysqlServerEndpoint")] + public string? MysqlServerEndpointSuffix { get; set; } + + [JsonPropertyName("postgresqlServerEndpoint")] + public string? PostgresqlServerEndpointSuffix { get; set; } + + [JsonPropertyName("mariadbServerEndpoint")] + public string? MariadbServerEndpointSuffix { get; set; } + + [JsonPropertyName("synapseAnalytics")] + public string? SynapseAnalyticsSuffix { get; set; } + } + + // Critical properties here + public string? DefaultTokenScope { get; set; } - public AzureEnvironment? AzureEnvironment { get; set; } - public ArmEnvironment ArmEnvironment { get; set; } + public ArmEnvironment? ArmEnvironment { get; set; } public string? Domain { get; set; } public AzureEnvironmentConfig? AzureEnvironmentConfig { get; set; } - public static async Task CreateAsync(string azureCloudName = DefaultAzureCloudName, string azureCloudMetadataUrlApiVersion = defaultAzureCloudMetadataUrlApiVersion) + // Helper properties + + public Uri? AuthorityHost { get; set; } + + /// + /// Azure cloud endpoints from cloud name + /// + /// Name of Azure cloud, either from IMDS or the resource manager metadata/endpoints query. + /// . + /// + /// + /// + /// Recognized cloud names are: + /// All generally available global Azure regions: AzureCloud, AzurePublicCloud + /// Azure Government: AzureUSGovernment: AzureUSGovernmentCloud + /// Microsoft Azure operated by 21Vianet: AzureChinaCloud + /// + public static Task FromKnownCloudNameAsync(string cloudName = DefaultAzureCloudName, string azureCloudMetadataUrlApiVersion = DefaultAzureCloudMetadataUrlApiVersion, IOptions? retryPolicyOptions = default) { - // It's critical that this succeeds for TES to function - // These URLs are expected to always be available - string domain; - string defaultTokenScope; - AzureEnvironment azureEnvironment; - ArmEnvironment armEnvironment; - // Names defined here: https://github.com/Azure/azure-sdk-for-net/blob/bc9f38eca0d8abbf0697dd3e3e75220553eeeafa/sdk/identity/Azure.Identity/src/AzureAuthorityHosts.cs#L11 - switch (azureCloudName.ToUpperInvariant()) + return cloudName.ToLowerInvariant() switch { - case "AZURECLOUD": - domain = "azure.com"; - // The double slash is intentional for the public cloud. - // https://github.com/Azure/azure-sdk-for-net/blob/bc9f38eca0d8abbf0697dd3e3e75220553eeeafa/sdk/identity/Azure.Identity/src/AzureAuthorityHosts.cs#L53 - defaultTokenScope = $"https://management.{domain}//.default"; - azureEnvironment = AzureEnvironment.AzureGlobalCloud; - armEnvironment = ArmEnvironment.AzurePublicCloud; - break; - case "AZUREUSGOVERNMENT": - domain = "usgovcloudapi.net"; - defaultTokenScope = $"https://management.{domain}/.default"; - azureEnvironment = AzureEnvironment.AzureUSGovernment; - armEnvironment = ArmEnvironment.AzureGovernment; - break; - case "AZURECHINACLOUD": - domain = "chinacloudapi.cn"; - defaultTokenScope = $"https://management.{domain}/.default"; - azureEnvironment = AzureEnvironment.AzureChinaCloud; - armEnvironment = ArmEnvironment.AzureChina; - break; - default: - throw new ArgumentException($"Invalid Azure cloud name: {azureCloudName}"); - } - - string azureCloudMetadataUrl = $"https://management.{domain}/metadata/endpoints?api-version={azureCloudMetadataUrlApiVersion}"; + "azurecloud" => FromMetadataEndpointsAsync(AzurePublicCloud, azureCloudMetadataUrlApiVersion, retryPolicyOptions), + "azurepubliccloud" => FromMetadataEndpointsAsync(AzurePublicCloud, azureCloudMetadataUrlApiVersion, retryPolicyOptions), + "azureusgovernmentcloud" => FromMetadataEndpointsAsync(AzureUSGovernmentCloud, azureCloudMetadataUrlApiVersion, retryPolicyOptions), + "azureusgovernment" => FromMetadataEndpointsAsync(AzureUSGovernmentCloud, azureCloudMetadataUrlApiVersion, retryPolicyOptions), + "azurechinacloud" => FromMetadataEndpointsAsync(AzureChinaCloud, azureCloudMetadataUrlApiVersion, retryPolicyOptions), + null => throw new ArgumentNullException(nameof(cloudName)), + _ => throw new ArgumentOutOfRangeException(nameof(cloudName)), + }; + } - var retryPolicy = Policy - .Handle() - .WaitAndRetryAsync(10, retryAttempt => TimeSpan.FromSeconds(30), onRetry: (exception, timespan, retryAttempt, context) => - { - Console.WriteLine($"Attempt {retryAttempt}: Retrying AzureCloudConfig creation due to error: {exception.Message}. {exception}"); - }); + private static readonly Uri AzurePublicCloud = Azure.ResourceManager.ArmEnvironment.AzurePublicCloud.Endpoint; + private static readonly Uri AzureUSGovernmentCloud = Azure.ResourceManager.ArmEnvironment.AzureGovernment.Endpoint; + private static readonly Uri AzureChinaCloud = Azure.ResourceManager.ArmEnvironment.AzureChina.Endpoint; - using var httpClient = new HttpClient(); + /// + /// Azure cloud endpoints from cloud management endpoints + /// + /// Azure cloud resource management endpoint. + /// + public static async Task FromMetadataEndpointsAsync(Uri cloudManagement, string azureCloudMetadataUrlApiVersion = DefaultAzureCloudMetadataUrlApiVersion, IOptions? retryPolicyOptions = default) + { + ArgumentNullException.ThrowIfNull(cloudManagement); + var retryPolicy = new RetryPolicyBuilder(retryPolicyOptions ?? Microsoft.Extensions.Options.Options.Create(new())).DefaultRetryHttpResponseMessagePolicyBuilder().SetOnRetryBehavior().AsyncBuildPolicy(); + HttpResponseMessage response; - return await retryPolicy.ExecuteAsync(async () => { - var httpResponse = await httpClient.GetAsync(azureCloudMetadataUrl); - httpResponse.EnsureSuccessStatusCode(); - var jsonString = await httpResponse.Content.ReadAsStringAsync(); - var config = JsonSerializer.Deserialize(jsonString, AzureCloudConfigContext.Default.AzureCloudConfig)!; - config.DefaultTokenScope = defaultTokenScope; - config.AzureEnvironment = azureEnvironment; - config.ArmEnvironment = armEnvironment; - config.Domain = domain; - - config.AzureEnvironmentConfig = new AzureEnvironmentConfig - { - AzureAuthorityHostUrl = config.Authentication?.LoginEndpointUrl, - TokenScope = config.DefaultTokenScope, - StorageUrlSuffix = config.Suffixes?.StorageSuffix, - }; - - return config; - }); - } - } - - [JsonSerializable(typeof(AzureCloudConfig))] - public partial class AzureCloudConfigContext : JsonSerializerContext - { } + using HttpClient client = new(); + response = await retryPolicy.ExecuteAsync(() => + client.SendAsync(new(HttpMethod.Get, new UriBuilder(cloudManagement) { Path = "/metadata/endpoints", Query = $"api-version={azureCloudMetadataUrlApiVersion}" }.Uri))); + } - public class AuthenticationDetails - { - [JsonPropertyName("loginEndpoint")] - public string LoginEndpointUrl { get; set; } = "https://login.microsoftonline.com"; + response.EnsureSuccessStatusCode(); - [JsonPropertyName("audiences")] - public List? Audiences { get; set; } + var jsonString = await response.Content.ReadAsStringAsync(); + var config = JsonSerializer.Deserialize(jsonString, AzureCloudConfigContext.Default.AzureCloudConfig)!; + config.ArmEnvironment = GetEnvironment(config.Name); + config.DefaultTokenScope = CreateScope(config); + config.Domain = GetDomain(config.ResourceManagerUrl); + config.AzureEnvironmentConfig = new AzureEnvironmentConfig(config.Authentication?.LoginEndpointUrl, config.DefaultTokenScope, config.Suffixes?.StorageSuffix); + config.AuthorityHost = GetAuthorityHost(config); + return config; - [JsonPropertyName("tenant")] - public string? Tenant { get; set; } + static string? CreateScope(AzureCloudConfig config) + { + if (config.ArmEnvironment is not null) + { + return config.ArmEnvironment.Value.DefaultScope; + } - [JsonPropertyName("identityProvider")] - public string? IdentityProvider { get; set; } - } + var audience = config.Authentication?.Audiences?.LastOrDefault(); - public class EndpointSuffixes - { - [JsonPropertyName("acrLoginServer")] - public string? AcrLoginServerSuffix { get; set; } + if (!string.IsNullOrWhiteSpace(audience)) + { + return audience + @"/.default"; + } - [JsonPropertyName("sqlServerHostname")] - public string? SqlServerHostnameSuffix { get; set; } + return default; + } - [JsonPropertyName("keyVaultDns")] - public string? KeyVaultDnsSuffix { get; set; } + static string? GetDomain(string? resourceManagerUrl) + { + if (!string.IsNullOrWhiteSpace(resourceManagerUrl)) + { + return new Uri(resourceManagerUrl).Host[@"management.".Length..]; + } - [JsonPropertyName("storage")] - public string? StorageSuffix { get; set; } + return default; + } - [JsonPropertyName("storageSyncEndpointSuffix")] - public string? StorageSyncEndpointSuffix { get; set; } + static Uri? GetAuthorityHost(AzureCloudConfig config) + { + return config.Name switch + { + "AzureCloud" => AzureAuthorityHosts.AzurePublicCloud, + "AzureUSGovernment" => AzureAuthorityHosts.AzureGovernment, + "AzureChinaCloud" => AzureAuthorityHosts.AzureChina, + // Environment.GetEnvironmentVariable("AZURE_AUTHORITY_HOST") + _ => GetFromConfig(), + }; - [JsonPropertyName("mhsmDns")] - public string? ManagedHsmDnsSuffix { get; set; } + Uri? GetFromConfig() + { + var authorityHost = config.Authentication?.LoginEndpointUrl; - [JsonPropertyName("mysqlServerEndpoint")] - public string? MysqlServerEndpointSuffix { get; set; } + if (!string.IsNullOrWhiteSpace(authorityHost)) + { + return new(authorityHost); + } - [JsonPropertyName("postgresqlServerEndpoint")] - public string? PostgresqlServerEndpointSuffix { get; set; } + return default; + } + } - [JsonPropertyName("mariadbServerEndpoint")] - public string? MariadbServerEndpointSuffix { get; set; } + static ArmEnvironment? GetEnvironment(string? name) + => name?.ToLowerInvariant() switch + { + "azurepubliccloud" => Azure.ResourceManager.ArmEnvironment.AzurePublicCloud, + "azurecloud" => Azure.ResourceManager.ArmEnvironment.AzurePublicCloud, + "azureusgovernmentcloud" => Azure.ResourceManager.ArmEnvironment.AzureGovernment, + "azureusgovernment" => Azure.ResourceManager.ArmEnvironment.AzureGovernment, + "azurechinacloud" => Azure.ResourceManager.ArmEnvironment.AzureChina, + _ => default, + }; + } - [JsonPropertyName("synapseAnalytics")] - public string? SynapseAnalyticsSuffix { get; set; } + //public static async Task CreateAsync(string azureCloudName = DefaultAzureCloudName, string azureCloudMetadataUrlApiVersion = DefaultAzureCloudMetadataUrlApiVersion) + //{ + // // It's critical that this succeeds for TES to function + // // These URLs are expected to always be available + // string domain; + // string defaultTokenScope; + // ArmEnvironment armEnvironment; + // // Names defined here: https://github.com/Azure/azure-sdk-for-net/blob/bc9f38eca0d8abbf0697dd3e3e75220553eeeafa/sdk/identity/Azure.Identity/src/AzureAuthorityHosts.cs#L11 + // switch (azureCloudName.ToUpperInvariant()) + // { + // case "AZURECLOUD": + // domain = "azure.com"; + // // The double slash is intentional for the public cloud. + // // https://github.com/Azure/azure-sdk-for-net/blob/bc9f38eca0d8abbf0697dd3e3e75220553eeeafa/sdk/identity/Azure.Identity/src/AzureAuthorityHosts.cs#L53 + // defaultTokenScope = $"https://management.{domain}//.default"; + // armEnvironment = Azure.ResourceManager.ArmEnvironment.AzurePublicCloud; + // break; + // case "AZUREUSGOVERNMENT": + // domain = "usgovcloudapi.net"; + // defaultTokenScope = $"https://management.{domain}/.default"; + // armEnvironment = Azure.ResourceManager.ArmEnvironment.AzureGovernment; + // break; + // case "AZURECHINACLOUD": + // domain = "chinacloudapi.cn"; + // defaultTokenScope = $"https://management.{domain}/.default"; + // armEnvironment = Azure.ResourceManager.ArmEnvironment.AzureChina; + // break; + // default: + // throw new ArgumentException($"Invalid Azure cloud name: {azureCloudName}"); + // } + + // string azureCloudMetadataUrl = $"https://management.{domain}/metadata/endpoints?api-version={azureCloudMetadataUrlApiVersion}"; + + // var retryPolicy = Policy + // .Handle() + // .WaitAndRetryAsync(10, retryAttempt => TimeSpan.FromSeconds(30), onRetry: (exception, timespan, retryAttempt, context) => + // { + // Console.WriteLine($"Attempt {retryAttempt}: Retrying AzureCloudConfig creation due to error: {exception.Message}. {exception}"); + // }); + + // using var httpClient = new HttpClient(); + + // return await retryPolicy.ExecuteAsync(async () => + // { + // var httpResponse = await httpClient.GetAsync(azureCloudMetadataUrl); + // httpResponse.EnsureSuccessStatusCode(); + // var jsonString = await httpResponse.Content.ReadAsStringAsync(); + // var config = JsonSerializer.Deserialize(jsonString, AzureCloudConfigContext.Default.AzureCloudConfig)!; + // config.DefaultTokenScope = defaultTokenScope; + // config.ArmEnvironment = armEnvironment; + // config.Domain = domain; + // config.AzureEnvironmentConfig = new AzureEnvironmentConfig(config.Authentication?.LoginEndpointUrl, config.DefaultTokenScope, config.Suffixes?.StorageSuffix); + // config.AuthorityHost = new(config.Authentication!.LoginEndpointUrl!); + // return config; + // }); + //} } + + [JsonSerializable(typeof(AzureCloudConfig))] + public partial class AzureCloudConfigContext : JsonSerializerContext + { } } diff --git a/src/CommonUtilities/AzureEnvironmentConfig.cs b/src/CommonUtilities/AzureEnvironmentConfig.cs index 99b2e6978..75de40022 100644 --- a/src/CommonUtilities/AzureEnvironmentConfig.cs +++ b/src/CommonUtilities/AzureEnvironmentConfig.cs @@ -1,12 +1,16 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using CommonUtilities.AzureCloud; + namespace CommonUtilities { - public class AzureEnvironmentConfig + public record class AzureEnvironmentConfig(string? AzureAuthorityHostUrl, string? TokenScope, string? StorageUrlSuffix) { - public string? AzureAuthorityHostUrl { get; set; } - public string? TokenScope { get; set; } - public string? StorageUrlSuffix { get; set; } + public static AzureEnvironmentConfig FromArmEnvironmentEndpoints(AzureCloudConfig azureCloudConfig) + { + ArgumentNullException.ThrowIfNull(azureCloudConfig); + return new(azureCloudConfig.Authentication?.LoginEndpointUrl, azureCloudConfig.DefaultTokenScope, azureCloudConfig.Suffixes?.StorageSuffix); + } } } diff --git a/src/CommonUtilities/AzureServicesConnectionStringCredential.cs b/src/CommonUtilities/AzureServicesConnectionStringCredential.cs new file mode 100644 index 000000000..04af4ff68 --- /dev/null +++ b/src/CommonUtilities/AzureServicesConnectionStringCredential.cs @@ -0,0 +1,515 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//using System.Security.Cryptography.X509Certificates; +using System.Text.RegularExpressions; +using Azure.Core; +using CommonUtilities.AzureCloud; +using Microsoft.Extensions.Configuration; + +namespace CommonUtilities +{ + /// + /// Enables authentication to Microsoft Entra ID using AzureServicesAuthConnectionString syntax + /// + /// This is adapted from Microsoft.Azure.Services.AppAuthentication.AzureServiceTokenProvider and Azure.Identity.EnvironmentCredential + public class AzureServicesConnectionStringCredential : TokenCredential + { + private readonly TokenCredential credential; + //private readonly AzureServicesConnectionStringCredentialOptions options; + + public AzureServicesConnectionStringCredential(AzureServicesConnectionStringCredentialOptions options) + { + ArgumentNullException.ThrowIfNull(options); + //options.Validate(); + + //this.options = options; + this.credential = AzureServicesConnectionStringCredentialFactory.Create(options); + } + + public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return credential.GetToken(requestContext, cancellationToken); + } + + public override ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) + { + return credential.GetTokenAsync(requestContext, cancellationToken); + } + } + + /// + /// Options used to configure the . + /// + /// This is adapted from Azure.Identity.EnvironmentCredentialOptions + public class AzureServicesConnectionStringCredentialOptions : Azure.Identity.TokenCredentialOptions + { + [Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructor] + public AzureServicesConnectionStringCredentialOptions(IConfiguration? configuration, AzureCloudConfig armEndpoints) + : this() + { + Configuration = configuration; + SetInitialState(armEndpoints); + ConnectionString = GetEnvironmentVariable("AzureServicesAuthConnectionString")!; + } + + public AzureServicesConnectionStringCredentialOptions(string connectionString, AzureCloudConfig armEndpoints) + : this() + { + SetInitialState(armEndpoints); + ConnectionString = connectionString; + } + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + private AzureServicesConnectionStringCredentialOptions() +#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable. + { + AdditionallyAllowedTenants = []; + } + + private void SetInitialState(AzureCloudConfig armEndpoints) + { + (GetEnvironmentVariable("AZURE_ADDITIONALLY_ALLOWED_TENANTS") ?? string.Empty).Split([';'], StringSplitOptions.RemoveEmptyEntries).ForEach(AdditionallyAllowedTenants.Add); + TenantId = GetEnvironmentVariable("AZURE_TENANT_ID")!; + AuthorityHost = armEndpoints.AuthorityHost ?? new(armEndpoints.Authentication?.LoginEndpointUrl ?? throw new ArgumentException("AuthorityHost is missing", nameof(armEndpoints))); + Audience = armEndpoints.ArmEnvironment?.Audience ?? armEndpoints.Authentication?.Audiences?.LastOrDefault() ?? throw new ArgumentException("Audience is missing", nameof(armEndpoints)); + Resource = new(armEndpoints.ResourceManagerUrl ?? throw new ArgumentException("ResourceManager is missing", nameof(armEndpoints))); + + if (string.IsNullOrWhiteSpace(TenantId)) + { + TenantId = armEndpoints.Authentication?.Tenant ?? throw new ArgumentException("TenantId is missing", nameof(armEndpoints)); + } + } + + /// + /// The connection string to connect to azure services. This value defaults to the value of the environment variable AzureServicesAuthConnectionString. + /// + public string ConnectionString { get; set; } + + /// + /// The authentication audience. This value is related to Azure.ResourceManager.ArmEnvironment.DefaultScope. + /// + public string Audience { get; set; } + + /// + /// The azure service for which to obtain credentials. + /// + public Uri Resource { get; set; } + + /// + /// The ID of the tenant to which the credential will authenticate by default. This value defaults to the value of the environment variable AZURE_TENANT_ID. + /// + public string TenantId { get; set; } + + /// + /// Gets or sets the setting which determines whether or not instance discovery is performed when attempting to authenticate. + /// Setting this to true will completely disable both instance discovery and authority validation. + /// This functionality is intended for use in scenarios where the metadata endpoint cannot be reached, such as in private clouds or Azure Stack. + /// The process of instance discovery entails retrieving authority metadata from https://login.microsoft.com/ to validate the authority. + /// By setting this to true, the validation of the authority is disabled. + /// As a result, it is crucial to ensure that the configured authority host is valid and trustworthy." + /// + public bool DisableInstanceDiscovery { get; set; } + + /// + /// Specifies tenants in addition to the specified for which the credential may acquire tokens. + /// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the logged in account can access. + /// If no value is specified for , this option will have no effect on that authentication method, and the credential will acquire tokens for any requested tenant when using that method. + /// This value defaults to the value of the environment variable AZURE_ADDITIONALLY_ALLOWED_TENANTS. + /// + public IList AdditionallyAllowedTenants { get; } + + internal IConfiguration? Configuration { get; } + + private string? GetEnvironmentVariable(string key) => GetConfigurationVariable(key) ?? Environment.GetEnvironmentVariable(key); + + private string? GetConfigurationVariable(string key) => Configuration is null ? default : Configuration[key]; + + //internal void Validate() + //{ + // throw new NotImplementedException(); + //} + + internal Azure.Identity.AzureCliCredential CreateAzureCliCredential() + { + var result = new Azure.Identity.AzureCliCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; + CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + return new(result); + } + + internal Azure.Identity.VisualStudioCredential CreateVisualStudioCredential() + { + var result = new Azure.Identity.VisualStudioCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; + CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + return new(result); + } + + internal Azure.Identity.VisualStudioCodeCredential CreateVisualStudioCodeCredential() + { + var result = new Azure.Identity.VisualStudioCodeCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; + CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + return new(result); + } + + //internal Azure.Identity.InteractiveBrowserCredential CreateInteractiveBrowserCredential() + //{ + // var result = new Azure.Identity.InteractiveBrowserCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery }; + // CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + // return new(result); + //} + + //internal Azure.Identity.ClientCertificateCredential CreateClientCertificateCredential(string appId, X509Certificate2 certificate, string tenantId) + //{ + // var result = new Azure.Identity.ClientCertificateCredentialOptions { AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery }; + // CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + // return new(string.IsNullOrEmpty(tenantId) ? TenantId : tenantId, appId, certificate, result); + //} + + internal Azure.Identity.ClientSecretCredential CreateClientSecretCredential(string appId, string appKey, string tenantId) + { + var result = new Azure.Identity.ClientSecretCredentialOptions { AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery }; + CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); + return new(string.IsNullOrEmpty(tenantId) ? TenantId : tenantId, appId, appKey, result); + } + + internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential(int _1, string appId) + { + return new(appId, this); + } + + internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential(int _1) + { + return new(null, this); + } + + void CopyAdditionallyAllowedTenants(IList additionalTenants) + { + foreach (var tenant in AdditionallyAllowedTenants) + { + additionalTenants.Add(tenant); + } + } + } + + // adapted from https://github.com/Azure/azure-sdk-for-net/blob/main/sdk/mgmtcommon/AppAuthentication/Azure.Services.AppAuthentication/AzureServiceTokenProviderFactory.cs + // Implements https://learn.microsoft.com/en-us/dotnet/api/overview/azure/app-auth-migration?view=azure-dotnet + internal partial class AzureServicesConnectionStringCredentialFactory + { + private const string RunAs = "RunAs"; + private const string Developer = "Developer"; + private const string AzureCli = "AzureCLI"; + private const string VisualStudio = "VisualStudio"; + private const string VisualStudioCode = "VisualStudioCode"; + private const string DeveloperTool = "DeveloperTool"; + private const string CurrentUser = "CurrentUser"; + private const string App = "App"; + private const string AppId = "AppId"; + private const string AppKey = "AppKey"; + private const string TenantId = "TenantId"; + private const string CertificateSubjectName = "CertificateSubjectName"; + private const string CertificateThumbprint = "CertificateThumbprint"; + private const string KeyVaultCertificateSecretIdentifier = "KeyVaultCertificateSecretIdentifier"; + //private const string KeyVaultUserAssignedManagedIdentityId = "KeyVaultUserAssignedManagedIdentityId"; + private const string CertificateStoreLocation = "CertificateStoreLocation"; + private const string MsiRetryTimeout = "MsiRetryTimeout"; + + // taken from https://github.com/dotnet/corefx/blob/master/src/Common/src/System/Data/Common/DbConnectionOptions.Common.cs + [GeneratedRegex( // may not contain embedded null except trailing last value + "([\\s;]*" // leading whitespace and extra semicolons + + "(?![\\s;])" // key does not start with space or semicolon + + "(?([^=\\s\\p{Cc}]|\\s+[^=\\s\\p{Cc}]|\\s+==|==)+)" // allow any visible character for keyname except '=' which must quoted as '==' + + "\\s*=(?!=)\\s*" // the equal sign divides the key and value parts + + "(?" + + "(\"([^\"\u0000]|\"\")*\")" // double quoted string, " must be quoted as "" + + "|" + + "('([^'\u0000]|'')*')" // single quoted string, ' must be quoted as '' + + "|" + + "((?![\"'\\s])" // unquoted value must not start with " or ' or space, would also like = but too late to change + + "([^;\\s\\p{Cc}]|\\s+[^;\\s\\p{Cc}])*" // control characters must be quoted + + "(? { CertificateSubjectName, CertificateThumbprint }, options.ConnectionString); + //ValidateAttribute(connectionSettings, CertificateStoreLocation, options.ConnectionString); + //ValidateStoreLocation(connectionSettings, options.ConnectionString); + //ValidateAttribute(connectionSettings, TenantId, options.ConnectionString); + + //azureServiceTokenCredential = + // options.CreateClientCertificateCredential( + // appId, + // GetCertificates( + // connectionSettings.ContainsKey(CertificateThumbprint) + // ? connectionSettings[CertificateThumbprint] + // : connectionSettings[CertificateSubjectName], + // connectionSettings.ContainsKey(CertificateThumbprint), + // Enum.Parse(connectionSettings[CertificateStoreLocation])) + // .Single(), + // connectionSettings[TenantId]); + + throw new ArgumentException("Connection string " + connectionString + " is not supported. CertificateStoreLocation is deprecated."); + } + else if (connectionSettings.ContainsKey(CertificateThumbprint) || + connectionSettings.ContainsKey(CertificateSubjectName)) + { + // if certificate thumbprint or subject name are specified but certificate store location is not, throw error + throw new ArgumentException($"Connection string {connectionString} is not valid. Must contain '{CertificateStoreLocation}' attribute and it must not be empty " + + $"when using '{CertificateThumbprint}' and '{CertificateSubjectName}' attributes"); + } + else if (connectionSettings.ContainsKey(KeyVaultCertificateSecretIdentifier)) + { + throw new ArgumentException("Connection string " + connectionString + " is not supported. KeyVaultCertificateSecretIdentifier is deprecated."); + + //ValidateMsiRetryTimeout(connectionSettings, options.ConnectionString); + + //var msiRetryTimeout = connectionSettings.ContainsKey(MsiRetryTimeout) + // ? int.Parse(connectionSettings[MsiRetryTimeout]) + // : 0; + //connectionSettings.TryGetValue(KeyVaultUserAssignedManagedIdentityId, out var keyVaultUserAssignedManagedIdentityId); + + //azureServiceTokenCredential = + // new ClientCertificateAzureServiceTokenProvider( + // appId, + // connectionSettings[KeyVaultCertificateSecretIdentifier], + // ClientCertificateAzureServiceTokenProvider.CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, + // null, // storeLocation unused + // azureAdInstance, + // connectionSettings.ContainsKey(TenantId) // tenantId can be specified in connection string or retrieved from Key Vault access token later + // ? connectionSettings[TenantId] + // : default, + //msiRetryTimeout, + // keyVaultUserAssignedManagedIdentityId, + // new AdalAuthenticationContext(httpClientFactory)); + } + else if (connectionSettings.TryGetValue(AppKey, out var appKey)) + { + ValidateAttribute(connectionSettings, TenantId, options.ConnectionString); + + azureServiceTokenCredential = + options.CreateClientSecretCredential( + appId, + appKey, + connectionSettings[TenantId]); + } + else + { + ValidateMsiRetryTimeout(connectionSettings, options.ConnectionString); + + // If certificate or client secret are not specified, use the specified managed identity + azureServiceTokenCredential = options.CreateManagedIdentityCredential( + connectionSettings.TryGetValue(MsiRetryTimeout, out var value) + ? int.Parse(value) + : 0, + appId); + } + } + else + { + ValidateMsiRetryTimeout(connectionSettings, options.ConnectionString); + + // If AppId is not specified, use Managed Service Identity + azureServiceTokenCredential = options.CreateManagedIdentityCredential( + connectionSettings.TryGetValue(MsiRetryTimeout, out var value) + ? int.Parse(value) + : 0); + } + } + else + { + throw new ArgumentException($"Connection string {connectionString} is not valid. RunAs value '{connectionSettings[RunAs]}' is not valid. " + + $"Allowed values are {Developer}, {CurrentUser}, or {App}"); + } + + return azureServiceTokenCredential; + + } + + //public static List GetCertificates(string subjectNameOrThumbprint, bool isThumbprint, StoreLocation location) + //{ + // var x509Store = new X509Store(StoreName.My, location); + // x509Store.Open(OpenFlags.ReadOnly); + // return x509Store.Certificates + // .Where(current => current is not null && current.HasPrivateKey && (isThumbprint && string.Equals(subjectNameOrThumbprint, current.Thumbprint, StringComparison.OrdinalIgnoreCase) || !isThumbprint && string.Equals(subjectNameOrThumbprint, current.Subject, StringComparison.OrdinalIgnoreCase))) + // .ToList(); + //} + + private static void ValidateAttribute(Dictionary connectionSettings, string attribute, string connectionString) + { + if (connectionSettings != null && + (!connectionSettings.ContainsKey(attribute) || string.IsNullOrWhiteSpace(connectionSettings[attribute]))) + { + throw new ArgumentException($"Connection string {connectionString} is not valid. Must contain '{attribute}' attribute and it must not be empty.", nameof(connectionString)); + } + } + + ///// + ///// Throws an exception if none of the attributes are in the connection string + ///// + ///// List of key value pairs in the connection string + ///// List of attributes to test + ///// The connection string specified + //private static void ValidateAttributes(Dictionary connectionSettings, List attributes, string connectionString) + //{ + // if (connectionSettings != null) + // { + // foreach (var attribute in attributes) + // { + // if (connectionSettings.ContainsKey(attribute)) + // { + // return; + // } + // } + + // throw new ArgumentException($"Connection string {connectionString} is not valid. Must contain at least one of {string.Join(" or ", attributes)} attributes.", nameof(connectionString)); + // } + //} + + //private static void ValidateStoreLocation(Dictionary connectionSettings, string connectionString) + //{ + // if (connectionSettings != null && connectionSettings.TryGetValue(CertificateStoreLocation, out var storeLocation)) + // { + // if (!string.IsNullOrWhiteSpace(storeLocation)) + // { + // if (!Enum.TryParse(storeLocation, true, out var _)) + // { + // throw new ArgumentException( + // $"Connection string {connectionString} is not valid. StoreLocation {storeLocation} is not valid. Valid values are CurrentUser and LocalMachine."); + // } + // } + // } + //} + + private static void ValidateMsiRetryTimeout(Dictionary connectionSettings, string connectionString) + { + if (connectionSettings != null && connectionSettings.TryGetValue(MsiRetryTimeout, out var value)) + { + if (!string.IsNullOrWhiteSpace(value)) + { + var timeoutString = value; + + var parseSucceeded = int.TryParse(timeoutString, out _); + if (!parseSucceeded) + { + throw new ArgumentException( + $"Connection string {connectionString} is not valid. MsiRetryTimeout {timeoutString} is not valid. Valid values are integers greater than or equal to 0."); + } + } + + } + } + + // adapted from https://github.com/dotnet/corefx/blob/master/src/Common/src/System/Data/Common/DbConnectionOptions.Common.cs + internal static Dictionary ParseConnectionString(string connectionString) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + connectionString = string.Empty; + } + + ArgumentException.ThrowIfNullOrEmpty(connectionString); + + var connectionSettings = new Dictionary(StringComparer.OrdinalIgnoreCase); + const int KeyIndex = 1, ValueIndex = 2; + var match = ConnectionStringPatternRegex.Match(connectionString); + if (!match.Success || match.Length != connectionString.Length) + { + throw new ArgumentException( + $"Connection string {connectionString} is not in a proper format. Expected format is Key1=Value1;Key2=Value2;", nameof(connectionString)); + } + + var indexValue = 0; + var keyValues = match.Groups[ValueIndex].Captures; + foreach (var keyNames in match.Groups[KeyIndex].Captures.Cast()) + { + var key = keyNames.Value.Replace("==", "="); + var value = keyValues[indexValue++].Value; + if (value.Length > 0) + { + switch (value[0]) + { + case '\"': + value = value[1..^1].Replace("\"\"", "\""); + break; + case '\'': + value = value[1..^1].Replace("\'\'", "\'"); + break; + default: + break; + } + } + + if (!string.IsNullOrWhiteSpace(key)) + { + if (!connectionSettings.ContainsKey(key)) + { + connectionSettings[key] = value; + } + else + { + throw new ArgumentException( + $"Connection string {connectionString} is not in a proper format. Key '{key}' is repeated.", nameof(connectionString)); + } + } + } + + return connectionSettings; + } + } +} diff --git a/src/CommonUtilities/CloudEnvironment.cs b/src/CommonUtilities/CloudEnvironment.cs new file mode 100644 index 000000000..600835cd9 --- /dev/null +++ b/src/CommonUtilities/CloudEnvironment.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Identity; +using Azure.ResourceManager; + +namespace CommonUtilities +{ + /// + /// Gets minimum cloud authentication metadata for resource management + /// Gets . + /// The default host of the Microsoft Entra authority for tenants in the Azure Cloud. + /// + public record struct CloudEnvironment(ArmEnvironment ArmEnvironment, Uri AzureAuthorityHost) + { + /// + /// Gets base URI of the management API endpoint. + /// + public readonly Uri Endpoint => ArmEnvironment.Endpoint; + + /// + /// Gets authentication audience. + /// + public readonly string Audience => ArmEnvironment.Audience; + + /// + /// Gets default authentication scope. + /// + public readonly string DefaultScope => ArmEnvironment.DefaultScope; + + private static CloudEnvironment FromLibraries(ArmEnvironment armEnvironment, Uri azureAuthorityHost) + { + ArgumentNullException.ThrowIfNull(armEnvironment); + ArgumentNullException.ThrowIfNull(azureAuthorityHost); + + return new(armEnvironment, azureAuthorityHost); + } + + private static CloudEnvironment FromLibraries(string? name) + => name switch + { + nameof(AzureAuthorityHosts.AzurePublicCloud) => FromLibraries(ArmEnvironment.AzurePublicCloud, AzureAuthorityHosts.AzurePublicCloud), + nameof(AzureAuthorityHosts.AzureGovernment) => FromLibraries(ArmEnvironment.AzureGovernment, AzureAuthorityHosts.AzureGovernment), + nameof(AzureAuthorityHosts.AzureChina) => FromLibraries(ArmEnvironment.AzureChina, AzureAuthorityHosts.AzureChina), + null => FromLibraries(ArmEnvironment.AzurePublicCloud, AzureAuthorityHosts.AzurePublicCloud), + _ => throw new InvalidOperationException("Unknown cloud."), + }; + + public static CloudEnvironment GetCloud(string? cloudName) => + cloudName?.ToLowerInvariant() switch + { + "azurepubliccloud" => FromLibraries(nameof(AzureAuthorityHosts.AzurePublicCloud)), + "azurecloud" => FromLibraries(nameof(AzureAuthorityHosts.AzurePublicCloud)), + "azureusgovernmentcloud" => FromLibraries(nameof(AzureAuthorityHosts.AzureGovernment)), + "azureusgovernment" => FromLibraries(nameof(AzureAuthorityHosts.AzureGovernment)), + "azurechinacloud" => FromLibraries(nameof(AzureAuthorityHosts.AzureChina)), + null => throw new ArgumentNullException(nameof(cloudName)), + _ => throw new ArgumentOutOfRangeException(nameof(cloudName)), + }; + } +} diff --git a/src/CommonUtilities/CommonUtilities.csproj b/src/CommonUtilities/CommonUtilities.csproj index e0f934d99..0488cde70 100644 --- a/src/CommonUtilities/CommonUtilities.csproj +++ b/src/CommonUtilities/CommonUtilities.csproj @@ -5,20 +5,20 @@ enable enable - + - - - - + + + + + - diff --git a/src/CommonUtilities/ExpensiveObjectTestUtility.cs b/src/CommonUtilities/ExpensiveObjectTestUtility.cs index 485982d92..7d3e28491 100644 --- a/src/CommonUtilities/ExpensiveObjectTestUtility.cs +++ b/src/CommonUtilities/ExpensiveObjectTestUtility.cs @@ -10,6 +10,6 @@ namespace CommonUtilities /// public static class ExpensiveObjectTestUtility { - public static AzureCloudConfig AzureCloudConfig = AzureCloudConfig.CreateAsync().Result; + public static AzureCloudConfig AzureCloudConfig = AzureCloudConfig.FromKnownCloudNameAsync().Result; } } diff --git a/src/CommonUtilities/PagedInterfaceExtensions.cs b/src/CommonUtilities/PagedInterfaceExtensions.cs index 07891f8fb..34d7fe63e 100644 --- a/src/CommonUtilities/PagedInterfaceExtensions.cs +++ b/src/CommonUtilities/PagedInterfaceExtensions.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Microsoft.Azure.Management.ResourceManager.Fluent.Core; +using Azure; using Microsoft.Rest.Azure; using static CommonUtilities.RetryHandler; @@ -13,15 +13,14 @@ namespace CommonUtilities public static class PagedInterfaceExtensions { /// - /// Creates an from an + /// Splits an into pages and re-presents it as an to facilitate retry logic /// /// The type of objects to enumerate. - /// The to enumerate. + /// The to enumerate. /// An /// - public static IAsyncEnumerable ToAsyncEnumerable(this IPagedCollection source) - => new AsyncEnumerable(source); - + public static IAsyncEnumerable ToAsyncEnumerable(this AsyncPageable source) where T : notnull + => new AsyncPageableEnumerable(source); /// /// Creates an from an @@ -90,16 +89,44 @@ public static async Task> ExecuteWithRetryAsync(this Asyn } #region Implementation classes - private readonly struct AsyncEnumerable : IAsyncEnumerable + private readonly struct AsyncPageableEnumerable : IAsyncEnumerable where T : notnull { private readonly Func> _getEnumerator; + private readonly AsyncPageable _source; - public AsyncEnumerable(IPagedCollection source) + public AsyncPageableEnumerable(AsyncPageable source) { ArgumentNullException.ThrowIfNull(source); - _getEnumerator = c => new PagedCollectionEnumerator(source, c); + _source = source; + _getEnumerator = c => new AsyncPageableEnumerator(null!, GetNextPage, c); + } + + /// + IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) + => _getEnumerator(cancellationToken); + + private async Task> GetNextPage(Page page, CancellationToken cancellationToken) + { + var enumerator = (page switch + { + null => _source.AsPages(), + var x when string.IsNullOrEmpty(page.ContinuationToken) => null!, + _ => _source.AsPages(continuationToken: page.ContinuationToken) + })?.GetAsyncEnumerator(cancellationToken); + + if (await (enumerator?.MoveNextAsync(cancellationToken) ?? ValueTask.FromResult(false))) + { + return enumerator!.Current; + } + + return null!; } + } + + private readonly struct AsyncEnumerable : IAsyncEnumerable + { + private readonly Func> _getEnumerator; public AsyncEnumerable(IPage source, Func?>> nextPageFunc) { @@ -174,10 +201,12 @@ public PageEnumerator(IPage source, Func : PagingEnumerator> + private sealed class AsyncPageableEnumerator : PagingEnumerator> where T : notnull { - public PagedCollectionEnumerator(IPagedCollection source, CancellationToken cancellationToken) - : base(source, s => s.GetEnumerator(), (s, ct) => s.GetNextPageAsync(ct), cancellationToken) + public AsyncPageableEnumerator(Page source, Func, CancellationToken, Task>> nextPageFunc, CancellationToken cancellationToken) +#pragma warning disable CS8619 // Nullability of reference types in value doesn't match target type. + : base(source, s => (s.Values ?? []).GetEnumerator(), (s, ct) => nextPageFunc(s, ct), cancellationToken) +#pragma warning restore CS8619 // Nullability of reference types in value doesn't match target type. { } } diff --git a/src/CommonUtilities/RefreshableAzureServiceTokenProvider.cs b/src/CommonUtilities/RefreshableAzureServiceTokenProvider.cs deleted file mode 100644 index 855e7d1c0..000000000 --- a/src/CommonUtilities/RefreshableAzureServiceTokenProvider.cs +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Net.Http.Headers; -using Microsoft.Azure.Services.AppAuthentication; -using Microsoft.Rest; - -namespace CommonUtilities -{ - /// - /// ITokenProvider implementation based on AzureServiceTokenProvider from Microsoft.Azure.Services.AppAuthentication package. - /// - public class RefreshableAzureServiceTokenProvider : ITokenProvider - { - private readonly string resource; - private readonly string? tenantId; - private readonly AzureServiceTokenProvider tokenProvider; - - /// - /// Constructor. - /// - /// Resource to request tokens for - /// AAD tenant ID containing the resource - /// AAD instance to request tokens from - public RefreshableAzureServiceTokenProvider(string resource, string? tenantId = null, string azureAdInstance = "https://login.microsoftonline.com/") - { - ArgumentException.ThrowIfNullOrEmpty(resource); - ArgumentException.ThrowIfNullOrEmpty(azureAdInstance); - - this.resource = resource; - this.tenantId = tenantId; - - this.tokenProvider = new("RunAs=Developer; DeveloperTool=AzureCli", azureAdInstance: azureAdInstance); - } - - /// - /// Gets the authentication header with token. - /// - /// Cancellation token - /// Authentication header with token - public async Task GetAuthenticationHeaderAsync(CancellationToken cancellationToken) - { - // AzureServiceTokenProvider caches tokens internally and refreshes them before expiry. - // This method usually gets called on every request to set the authentication header. This ensures that we cache tokens, and also that we always get a valid one. - var token = await tokenProvider.GetAccessTokenAsync(resource, tenantId, cancellationToken); - return new("Bearer", token); - } - } -} diff --git a/src/CommonUtilities/UtilityExtensions.cs b/src/CommonUtilities/UtilityExtensions.cs index 23b4f7172..4c3378c2f 100644 --- a/src/CommonUtilities/UtilityExtensions.cs +++ b/src/CommonUtilities/UtilityExtensions.cs @@ -8,7 +8,7 @@ namespace CommonUtilities public static class UtilityExtensions { #region RFC 4648 Base32 - private static readonly char[] Rfc4648Base32 = new[] { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '2', '3', '4', '5', '6', '7' }; + private static readonly char[] Rfc4648Base32 = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '2', '3', '4', '5', '6', '7']; private const int GroupBitlength = 5; private const int BitsPerByte = 8; private const int LargestBitPosition = GroupBitlength - 1; @@ -44,31 +44,32 @@ public static string ConvertToBase32(this byte[] bytes) 2 => @"====", 3 => @"===", 4 => @"=", - _ => throw new InvalidOperationException(), // Keeps the compiler happy. + _ => throw new System.Diagnostics.UnreachableException(), }; #endregion + #region ConvertGroup /// - /// Converts each group (fixed number) of items into a new item + /// Converts each group (by count) of items into a new item /// /// Type of source items /// Intermediate type /// Type of the resultant items /// The source enumerable of type . /// The size of each group to create out of the entire enumeration. The last group may be smaller. - /// The function that prepares each into the value expected by . Its parameters are an item of type and the index of that item (starting from zero) within each group. - /// The function that creates the from each group of items. + /// The function that prepares each into the value expected by . Its parameters are an item of type and the index of that item (starting from zero) within each group. + /// The function that creates the from each group of items. /// An enumeration of from all of the groups. public static IEnumerable ConvertGroup( this IEnumerable source, int groupSize, - Func groupItemFunc, - Func, TResult> groupResultFunc) + Func source2Group, + Func, TResult> group2Result) => source .Select((value, index) => (Index: index, Value: value)) .GroupBy(tuple => tuple.Index / groupSize) .OrderBy(tuple => tuple.Key) - .Select(groups => groupResultFunc(groups.Select(item => groupItemFunc(item.Value, item.Index % groupSize)))); + .Select(groups => group2Result(groups.Select(item => source2Group(item.Value, item.Index % groupSize)))); /// /// Performs on each item in . @@ -101,5 +102,24 @@ public static void ForEach(this IEnumerable values, Action action) action(item); } } + #endregion + + #region AddRange + //public static void AddRange(this IList list, IEnumerable values) + //{ + // foreach (var value in values) + // { + // list.Add(value); + // }; + //} + + public static void AddRange(this IDictionary dictionary, IDictionary values) + { + foreach (var value in values) + { + dictionary.Add(value); + }; + } + #endregion } } diff --git a/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj b/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj index e8ea42c66..b444c6c00 100644 --- a/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj +++ b/src/GenerateBatchVmSkus/GenerateBatchVmSkus.csproj @@ -9,7 +9,7 @@ - + diff --git a/src/Tes.ApiClients.Tests/DrsHubApiClientTests.cs b/src/Tes.ApiClients.Tests/DrsHubApiClientTests.cs index 42881ff1e..80ef8f242 100644 --- a/src/Tes.ApiClients.Tests/DrsHubApiClientTests.cs +++ b/src/Tes.ApiClients.Tests/DrsHubApiClientTests.cs @@ -29,10 +29,7 @@ public void Setup() cachingRetryPolicyBuilder = new CachingRetryPolicyBuilder(appCache, Options.Create(retryPolicyOptions)); tokenCredentialsMock = new Mock(); - azureEnvironmentConfig = new CommonUtilities.AzureEnvironmentConfig() - { - TokenScope = "https://management.azure.com/.default" - }; + azureEnvironmentConfig = new CommonUtilities.AzureEnvironmentConfig(default, "https://management.azure.com/.default", default); apiClient = new DrsHubApiClient(DrsApiHost, tokenCredentialsMock.Object, cachingRetryPolicyBuilder, azureEnvironmentConfig, NullLogger.Instance); } @@ -57,8 +54,10 @@ public async Task GetDrsResolveRequestContent_ValidDrsUri_ReturnsValidRequestCon [TestMethod] public async Task GetDrsResolveApiResponse_ResponseWithAccessUrl_CanDeserializeJSon() { - var httpResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); - httpResponse.Content = new StringContent(ExpectedRsResolveResponseJson); + HttpResponseMessage httpResponse = new(System.Net.HttpStatusCode.OK) + { + Content = new StringContent(ExpectedRsResolveResponseJson) + }; var drsResolveResponse = await DrsHubApiClient.GetDrsResolveApiResponseAsync(httpResponse, CancellationToken.None); @@ -75,9 +74,9 @@ public async Task GetDrsResolveApiResponse_ResponseWithAccessUrl_CanDeserializeJ }"; private const string ExpectedDrsResolveRequestJson = @"{ - ""url"": ""drs://drs.foo"", - ""cloudPlatform"": ""azure"", - ""fields"":[""accessUrl""] + ""url"": ""drs://drs.foo"", + ""cloudPlatform"": ""azure"", + ""fields"":[""accessUrl""] }"; } } diff --git a/src/Tes.ApiClients.Tests/TerraIntegration/TerraWsmApiClientIntegrationTests.cs b/src/Tes.ApiClients.Tests/TerraIntegration/TerraWsmApiClientIntegrationTests.cs index c958d4264..fec39ef10 100644 --- a/src/Tes.ApiClients.Tests/TerraIntegration/TerraWsmApiClientIntegrationTests.cs +++ b/src/Tes.ApiClients.Tests/TerraIntegration/TerraWsmApiClientIntegrationTests.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. using CommonUtilities; -using CommonUtilities.AzureCloud; using CommonUtilities.Options; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Logging; @@ -23,10 +22,9 @@ public void Setup() var retryOptions = Microsoft.Extensions.Options.Options.Create(new RetryPolicyOptions()); var memoryCache = new MemoryCache(new MemoryCacheOptions()); - var config = ExpensiveObjectTestUtility.AzureCloudConfig.AzureEnvironmentConfig; + var config = AzureEnvironmentConfig.FromArmEnvironmentEndpoints(CommonUtilities.AzureCloud.AzureCloudConfig.FromKnownCloudNameAsync(retryPolicyOptions: retryOptions).Result); wsmApiClient = new TerraWsmApiClient(envInfo.WsmApiHost, new TestEnvTokenCredential(), new CachingRetryPolicyBuilder(memoryCache, retryOptions), config, TestLoggerFactory.Create()); - } [TestMethod] diff --git a/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs b/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs index 1a51ce232..3e672d8c1 100644 --- a/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs +++ b/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs @@ -32,7 +32,7 @@ public void SetUp() cache.Setup(c => c.CreateEntry(It.IsAny())).Returns(new Mock().Object); cacheAndRetryBuilder.SetupGet(c => c.AppCache).Returns(cache.Object); cacheAndRetryHandler = new(TestServices.RetryHandlersHelpers.GetCachingAsyncRetryPolicyMock(cacheAndRetryBuilder, c => c.DefaultRetryHttpResponseMessagePolicyBuilder())); - azureEnvironmentConfig = ExpensiveObjectTestUtility.AzureCloudConfig.AzureEnvironmentConfig!; + azureEnvironmentConfig = AzureEnvironmentConfig.FromArmEnvironmentEndpoints(CommonUtilities.AzureCloud.AzureCloudConfig.FromKnownCloudNameAsync(CommonUtilities.AzureCloud.AzureCloudConfig.DefaultAzureCloudName).Result); terraWsmApiClient = new TerraWsmApiClient(TerraApiStubData.WsmApiHost, tokenCredential.Object, cacheAndRetryBuilder.Object, azureEnvironmentConfig, NullLogger.Instance); diff --git a/src/Tes.ApiClients.Tests/Tes.ApiClients.Tests.csproj b/src/Tes.ApiClients.Tests/Tes.ApiClients.Tests.csproj index f01e8e866..93a566dde 100644 --- a/src/Tes.ApiClients.Tests/Tes.ApiClients.Tests.csproj +++ b/src/Tes.ApiClients.Tests/Tes.ApiClients.Tests.csproj @@ -13,10 +13,10 @@ - + - - + + diff --git a/src/Tes.ApiClients/Tes.ApiClients.csproj b/src/Tes.ApiClients/Tes.ApiClients.csproj index d82f89ede..645acf9b7 100644 --- a/src/Tes.ApiClients/Tes.ApiClients.csproj +++ b/src/Tes.ApiClients/Tes.ApiClients.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/Tes.Runner.Test/Storage/ArmUrlTransformationStrategyTests.cs b/src/Tes.Runner.Test/Storage/ArmUrlTransformationStrategyTests.cs index 7ba325d9c..a54b56677 100644 --- a/src/Tes.Runner.Test/Storage/ArmUrlTransformationStrategyTests.cs +++ b/src/Tes.Runner.Test/Storage/ArmUrlTransformationStrategyTests.cs @@ -5,6 +5,7 @@ using Azure.Storage.Blobs.Models; using Azure.Storage.Sas; using CommonUtilities; +using CommonUtilities.Options; using Moq; using Tes.Runner.Models; using Tes.Runner.Storage; @@ -26,7 +27,7 @@ public void SetUp() mockBlobServiceClient = new Mock(); RuntimeOptions options = new() { - AzureEnvironmentConfig = ExpensiveObjectTestUtility.AzureCloudConfig.AzureEnvironmentConfig + AzureEnvironmentConfig = AzureEnvironmentConfig.FromArmEnvironmentEndpoints(CommonUtilities.AzureCloud.AzureCloudConfig.FromKnownCloudNameAsync().Result) }; armUrlTransformationStrategy = new ArmUrlTransformationStrategy(_ => mockBlobServiceClient.Object, options); diff --git a/src/Tes.Runner.Test/Storage/UrlTransformationStrategyFactoryTests.cs b/src/Tes.Runner.Test/Storage/UrlTransformationStrategyFactoryTests.cs index 1908e69da..e25564635 100644 --- a/src/Tes.Runner.Test/Storage/UrlTransformationStrategyFactoryTests.cs +++ b/src/Tes.Runner.Test/Storage/UrlTransformationStrategyFactoryTests.cs @@ -83,18 +83,14 @@ public void CreateCombinedTerraTransformationStrategy_WithDrsHubSettings_Returns private static RuntimeOptions CreateRuntimeOptions() { - return new RuntimeOptions + return new() { - AzureEnvironmentConfig = new AzureEnvironmentConfig() - { - StorageUrlSuffix = @"core.windows.net", - TokenScope = ".default" - } + AzureEnvironmentConfig = new AzureEnvironmentConfig(default, ".default", @"core.windows.net") }; } private static TerraRuntimeOptions CreateTerraRuntimeOptions(string? drsHubHost = default) { - return new TerraRuntimeOptions + return new() { DrsHubApiHost = drsHubHost, WsmApiHost = "https://wsmhost.bio", diff --git a/src/Tes.Runner.Test/Tes.Runner.Test.csproj b/src/Tes.Runner.Test/Tes.Runner.Test.csproj index 3eb97aefd..eba2f2477 100644 --- a/src/Tes.Runner.Test/Tes.Runner.Test.csproj +++ b/src/Tes.Runner.Test/Tes.Runner.Test.csproj @@ -20,8 +20,8 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/src/Tes.Runner/Tes.Runner.csproj b/src/Tes.Runner/Tes.Runner.csproj index 2469f5f08..9d9742b32 100644 --- a/src/Tes.Runner/Tes.Runner.csproj +++ b/src/Tes.Runner/Tes.Runner.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/Tes.Runner/Transfer/SimpleScalingStrategy.cs b/src/Tes.Runner/Transfer/SimpleScalingStrategy.cs index 732ad3f8d..5638dad0d 100644 --- a/src/Tes.Runner/Transfer/SimpleScalingStrategy.cs +++ b/src/Tes.Runner/Transfer/SimpleScalingStrategy.cs @@ -1,12 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - namespace Tes.Runner.Transfer { public class SimpleScalingStrategy : IScalingStrategy diff --git a/src/Tes.SDK.Tests/Tes.SDK.Tests.csproj b/src/Tes.SDK.Tests/Tes.SDK.Tests.csproj index bc40c29d4..d542b3d0d 100644 --- a/src/Tes.SDK.Tests/Tes.SDK.Tests.csproj +++ b/src/Tes.SDK.Tests/Tes.SDK.Tests.csproj @@ -14,10 +14,10 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - + - - + + diff --git a/src/Tes/Tes.csproj b/src/Tes/Tes.csproj index f204740d0..cdac5b7f9 100644 --- a/src/Tes/Tes.csproj +++ b/src/Tes/Tes.csproj @@ -18,7 +18,6 @@ - diff --git a/src/TesApi.Tests/BatchPoolTests.cs b/src/TesApi.Tests/BatchPoolTests.cs index 95a37c7b2..0feae1ed6 100644 --- a/src/TesApi.Tests/BatchPoolTests.cs +++ b/src/TesApi.Tests/BatchPoolTests.cs @@ -6,12 +6,12 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.Azure.Batch; -using Microsoft.Azure.Management.Batch.Models; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Tes.Models; using TesApi.Web; using TesApi.Web.Management; +using TesApi.Web.Management.Batch; using TesApi.Web.Management.Models.Quotas; namespace TesApi.Tests @@ -141,13 +141,21 @@ private static TestServices.TestServiceProvider GetServiceProvid wrapAzureProxy: true, configuration: GetMockConfig(), azureProxy: PrepareMockAzureProxy(azureProxyReturn), + batchPoolManager: PrepareMockBatchPoolManager(azureProxyReturn), batchQuotaProvider: GetMockQuotaProvider(azureProxyReturn), batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn), accountResourceInformation: new("defaultbatchaccount", "defaultresourcegroup", "defaultsubscription", "defaultregion", "defaultendpoint")); } private static async Task AddPool(BatchScheduler batchPools, bool isPreemtable) - => (BatchPool)await batchPools.GetOrAddPoolAsync("key1", isPreemtable, (id, _1) => ValueTask.FromResult(new Pool(name: id, displayName: "display1", vmSize: "vmSize1")), System.Threading.CancellationToken.None); + => (BatchPool)await batchPools.GetOrAddPoolAsync("key1", isPreemtable, (id, _1) => ValueTask.FromResult(CreatePoolData(name: id, displayName: "display1", vmSize: "vmSize1")), System.Threading.CancellationToken.None); + + internal static Azure.ResourceManager.Batch.BatchAccountPoolData CreatePoolData(string name, string displayName = default, string vmSize = default) + { + Azure.ResourceManager.Batch.BatchAccountPoolData result = new() { DisplayName = displayName, VmSize = vmSize }; + result.Metadata.Add(new(string.Empty, name)); + return result; + } private static void TimeShift(TimeSpan shift, BatchPool pool) => pool.TimeShift(shift); @@ -157,7 +165,7 @@ private class AzureProxyReturnValues internal static AzureProxyReturnValues Get() => new(); - internal AzureBatchAccountQuotas BatchQuotas { get; set; } = new() { PoolQuota = 1, ActiveJobAndJobScheduleQuota = 1, DedicatedCoreQuotaPerVMFamily = new List() }; + internal AzureBatchAccountQuotas BatchQuotas { get; set; } = new() { PoolQuota = 1, ActiveJobAndJobScheduleQuota = 1, DedicatedCoreQuotaPerVMFamily = [] }; internal int ActivePoolCount { get; set; } = 0; internal Func> AzureProxyListComputeNodesAsync { get; set; } = (poolId, detailLevel) => AsyncEnumerable.Empty(); @@ -172,7 +180,7 @@ internal static AzureProxyReturnValues Get() internal bool PoolStateExists(string poolId) => poolState.ContainsKey(poolId); - private readonly Dictionary PoolMetadata)> poolState = []; + private readonly Dictionary PoolMetadata)> poolState = []; internal void SetPoolState( string id, @@ -185,7 +193,7 @@ internal void SetPoolState( Microsoft.Azure.Batch.Protocol.Models.AutoScaleRun autoScaleRun = default, bool? enableAutoScale = default, DateTime? creationTime = default, - IList poolMetadata = default) + IList poolMetadata = default) { if (poolState.TryGetValue(id, out var state)) { @@ -214,7 +222,7 @@ internal void SetPoolState( creationTime ?? state.CreationTime, metadata.Count == 0 ? null : metadata.Select(ConvertMetadata).ToList()); - static Microsoft.Azure.Batch.MetadataItem ConvertMetadata(KeyValuePair pair) + static MetadataItem ConvertMetadata(KeyValuePair pair) => new(pair.Key, pair.Value); } else @@ -229,14 +237,15 @@ internal void AzureProxyDeleteBatchPoolImpl(string poolId, System.Threading.Canc _ = poolState.Remove(poolId); } - internal string CreateBatchPoolImpl(Pool pool) + internal string CreateBatchPoolImpl(Azure.ResourceManager.Batch.BatchAccountPoolData pool) { - var poolId = pool.Name; + var poolIdItem = pool.Metadata.Single(i => string.IsNullOrEmpty(i.Name)); + pool.Metadata.Remove(poolIdItem); - poolState.Add(poolId, (default, default, default, default, Microsoft.Azure.Batch.Common.AllocationState.Steady, default, default, true, default, pool.Metadata?.Select(ConvertMetadata).ToList())); - return poolId; + poolState.Add(poolIdItem.Value, (default, default, default, default, Microsoft.Azure.Batch.Common.AllocationState.Steady, default, default, true, default, pool.Metadata?.Select(ConvertMetadata).ToList())); + return poolIdItem.Value; - static Microsoft.Azure.Batch.MetadataItem ConvertMetadata(Microsoft.Azure.Management.Batch.Models.MetadataItem item) + static MetadataItem ConvertMetadata(Azure.ResourceManager.Batch.Models.BatchAccountPoolMetadataItem item) => new(item.Name, item.Value); } @@ -292,30 +301,35 @@ private static Action> GetMockQuotaProvider(AzureProxy .ReturnsAsync(new BatchVmCoreQuota(batchQuotas.LowPriorityCoreQuota, true, batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, - batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota)).ToList(), + batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota ?? 0)).ToList(), new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota))); quotaProvider.Setup(p => p.GetVmCoreQuotaAsync(It.Is(l => l == false), It.IsAny())) .ReturnsAsync(new BatchVmCoreQuota(batchQuotas.DedicatedCoreQuota, false, batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, - batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota)).ToList(), + batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota ?? 0)).ToList(), new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota))); }); + private static Action> PrepareMockBatchPoolManager(AzureProxyReturnValues azureProxyReturnValues) + => azureProxy => + { + azureProxy.Setup(a => a.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())).Returns((Azure.ResourceManager.Batch.BatchAccountPoolData p, bool _1, System.Threading.CancellationToken _2) => Task.FromResult(azureProxyReturnValues.CreateBatchPoolImpl(p))); + azureProxy.Setup(a => a.DeleteBatchPoolAsync(It.IsAny(), It.IsAny())).Callback((poolId, cancellationToken) => azureProxyReturnValues.AzureProxyDeleteBatchPoolImpl(poolId, cancellationToken)).Returns(Task.CompletedTask); + }; + private static Action> PrepareMockAzureProxy(AzureProxyReturnValues azureProxyReturnValues) => azureProxy => { azureProxy.Setup(a => a.GetActivePoolsAsync(It.IsAny())).Returns(AsyncEnumerable.Empty()); azureProxy.Setup(a => a.GetBatchActivePoolCount()).Returns(azureProxyReturnValues.ActivePoolCount); - azureProxy.Setup(a => a.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())).Returns((Pool p, bool _1, System.Threading.CancellationToken _2) => Task.FromResult(azureProxyReturnValues.CreateBatchPoolImpl(p))); azureProxy.Setup(a => a.ListComputeNodesAsync(It.IsAny(), It.IsAny())).Returns((poolId, detailLevel) => azureProxyReturnValues.AzureProxyListComputeNodesAsync(poolId, detailLevel)); azureProxy.Setup(a => a.ListTasksAsync(It.IsAny(), It.IsAny())).Returns((jobId, detailLevel) => azureProxyReturnValues.AzureProxyListTasks(jobId, detailLevel)); azureProxy.Setup(a => a.DeleteBatchComputeNodesAsync(It.IsAny(), It.IsAny>(), It.IsAny())).Callback, System.Threading.CancellationToken>((poolId, computeNodes, cancellationToken) => azureProxyReturnValues.AzureProxyDeleteBatchComputeNodes(poolId, computeNodes, cancellationToken)).Returns(Task.CompletedTask); azureProxy.Setup(a => a.GetBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())).Returns((string id, System.Threading.CancellationToken cancellationToken, DetailLevel detailLevel) => Task.FromResult(azureProxyReturnValues.GetBatchPoolImpl(id))); azureProxy.Setup(a => a.GetFullAllocationStateAsync(It.IsAny(), It.IsAny())).Returns((string poolId, System.Threading.CancellationToken _1) => Task.FromResult(GetPoolStateFromSettingStateOrDefault(poolId))); - azureProxy.Setup(a => a.DeleteBatchPoolAsync(It.IsAny(), It.IsAny())).Callback((poolId, cancellationToken) => azureProxyReturnValues.AzureProxyDeleteBatchPoolImpl(poolId, cancellationToken)).Returns(Task.CompletedTask); FullBatchPoolAllocationState GetPoolStateFromSettingStateOrDefault(string poolId) { @@ -341,17 +355,12 @@ FullBatchPoolAllocationState GetPoolStateFromSettingStateOrDefault(string poolId private static IEnumerable<(string Key, string Value)> GetMockConfig() => Enumerable .Empty<(string Key, string Value)>() - .Append(("BatchScheduling:PoolRotationForcedDays", "0.000694444")); + .Append(("BatchScheduling:PoolRotationForcedDays", "0.000694444")) + .Append(("BatchScheduling:Prefix", "0123456789")); - private sealed class MockServiceClient : Microsoft.Azure.Batch.Protocol.BatchServiceClient + private sealed class MockServiceClient(Microsoft.Azure.Batch.Protocol.IComputeNodeOperations computeNode) : Microsoft.Azure.Batch.Protocol.BatchServiceClient { - private readonly Microsoft.Azure.Batch.Protocol.IComputeNodeOperations computeNode; - - public MockServiceClient(Microsoft.Azure.Batch.Protocol.IComputeNodeOperations computeNode) - { - this.computeNode = computeNode ?? throw new ArgumentNullException(nameof(computeNode)); - } - + private readonly Microsoft.Azure.Batch.Protocol.IComputeNodeOperations computeNode = computeNode ?? throw new ArgumentNullException(nameof(computeNode)); public override Microsoft.Azure.Batch.Protocol.IComputeNodeOperations ComputeNode => computeNode; } @@ -366,7 +375,7 @@ internal static CloudPool GeneratePool( Microsoft.Azure.Batch.Protocol.Models.AutoScaleRun autoScaleRun = default, bool? enableAutoScale = default, DateTime? creationTime = default, - IList metadata = default) + IList metadata = default) { if (default == creationTime) { @@ -375,13 +384,13 @@ internal static CloudPool GeneratePool( metadata ??= []; - var computeNodeOperations = new Mock(); - var batchServiceClient = new MockServiceClient(computeNodeOperations.Object); - var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient) }, null) - .Invoke(new object[] { batchServiceClient }); - var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer") }, null) - .Invoke(new object[] { protocolLayer }); - var modelPool = new Microsoft.Azure.Batch.Protocol.Models.CloudPool( + Mock computeNodeOperations = new(); + MockServiceClient batchServiceClient = new(computeNodeOperations.Object); + var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient)], null) + .Invoke([batchServiceClient]); + var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer")], null) + .Invoke([protocolLayer]); + Microsoft.Azure.Batch.Protocol.Models.CloudPool modelPool = new( id: id, currentDedicatedNodes: currentDedicatedNodes, currentLowPriorityNodes: currentLowPriorityNodes, @@ -392,11 +401,11 @@ internal static CloudPool GeneratePool( enableAutoScale: enableAutoScale, creationTime: creationTime, metadata: metadata.Select(ConvertMetadata).ToList()); - var pool = (CloudPool)typeof(CloudPool).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, new Type[] { typeof(BatchClient), typeof(Microsoft.Azure.Batch.Protocol.Models.CloudPool), typeof(IEnumerable) }, default) + var pool = (CloudPool)typeof(CloudPool).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, [typeof(BatchClient), typeof(Microsoft.Azure.Batch.Protocol.Models.CloudPool), typeof(IEnumerable)], default) .Invoke([parentClient, modelPool, null]); return pool; - static Microsoft.Azure.Batch.Protocol.Models.MetadataItem ConvertMetadata(Microsoft.Azure.Batch.MetadataItem item) + static Microsoft.Azure.Batch.Protocol.Models.MetadataItem ConvertMetadata(MetadataItem item) => item is null ? default : new(item.Name, item.Value); } @@ -407,15 +416,15 @@ internal static CloudTask GenerateTask(string jobId, string id, DateTime stateTr stateTransitionTime = DateTime.UtcNow; } - var computeNodeOperations = new Mock(); - var batchServiceClient = new MockServiceClient(computeNodeOperations.Object); - var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient) }, null) - .Invoke(new object[] { batchServiceClient }); - var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer") }, null) - .Invoke(new object[] { protocolLayer }); - var modelTask = new Microsoft.Azure.Batch.Protocol.Models.CloudTask(id: id, stateTransitionTime: stateTransitionTime, state: Microsoft.Azure.Batch.Protocol.Models.TaskState.Active); - var task = (CloudTask)typeof(CloudTask).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, new Type[] { typeof(BatchClient), typeof(string), typeof(Microsoft.Azure.Batch.Protocol.Models.CloudTask), typeof(IEnumerable) }, default) - .Invoke(new object[] { parentClient, jobId, modelTask, Enumerable.Empty() }); + Mock computeNodeOperations = new(); + MockServiceClient batchServiceClient = new(computeNodeOperations.Object); + var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient)], null) + .Invoke([batchServiceClient]); + var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer")], null) + .Invoke([protocolLayer]); + Microsoft.Azure.Batch.Protocol.Models.CloudTask modelTask = new(id: id, stateTransitionTime: stateTransitionTime, state: Microsoft.Azure.Batch.Protocol.Models.TaskState.Active); + var task = (CloudTask)typeof(CloudTask).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, [typeof(BatchClient), typeof(string), typeof(Microsoft.Azure.Batch.Protocol.Models.CloudTask), typeof(IEnumerable)], default) + .Invoke([parentClient, jobId, modelTask, Enumerable.Empty()]); return task; } @@ -426,15 +435,15 @@ internal static ComputeNode GenerateNode(string poolId, string id, bool isDedica stateTransitionTime = DateTime.UtcNow; } - var computeNodeOperations = new Mock(); - var batchServiceClient = new MockServiceClient(computeNodeOperations.Object); - var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient) }, null) - .Invoke(new object[] { batchServiceClient }); - var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, new Type[] { typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer") }, null) - .Invoke(new object[] { protocolLayer }); - var modelNode = new Microsoft.Azure.Batch.Protocol.Models.ComputeNode(stateTransitionTime: stateTransitionTime, id: id, affinityId: AffinityPrefix + id, isDedicated: isDedicated, state: isIdle ? Microsoft.Azure.Batch.Protocol.Models.ComputeNodeState.Idle : Microsoft.Azure.Batch.Protocol.Models.ComputeNodeState.Running); - var node = (ComputeNode)typeof(ComputeNode).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, new Type[] { typeof(BatchClient), typeof(string), typeof(Microsoft.Azure.Batch.Protocol.Models.ComputeNode), typeof(IEnumerable) }, default) - .Invoke(new object[] { parentClient, poolId, modelNode, null }); + Mock computeNodeOperations = new(); + MockServiceClient batchServiceClient = new(computeNodeOperations.Object); + var protocolLayer = typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.ProtocolLayer").GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient)], null) + .Invoke([batchServiceClient]); + var parentClient = (BatchClient)typeof(BatchClient).GetConstructor(System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic, null, [typeof(Microsoft.Azure.Batch.Protocol.BatchServiceClient).Assembly.GetType("Microsoft.Azure.Batch.IProtocolLayer")], null) + .Invoke([protocolLayer]); + Microsoft.Azure.Batch.Protocol.Models.ComputeNode modelNode = new(stateTransitionTime: stateTransitionTime, id: id, affinityId: AffinityPrefix + id, isDedicated: isDedicated, state: isIdle ? Microsoft.Azure.Batch.Protocol.Models.ComputeNodeState.Idle : Microsoft.Azure.Batch.Protocol.Models.ComputeNodeState.Running); + var node = (ComputeNode)typeof(ComputeNode).GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, default, [typeof(BatchClient), typeof(string), typeof(Microsoft.Azure.Batch.Protocol.Models.ComputeNode), typeof(IEnumerable)], default) + .Invoke([parentClient, poolId, modelNode, null]); return node; } } diff --git a/src/TesApi.Tests/BatchSchedulerTests.cs b/src/TesApi.Tests/BatchSchedulerTests.cs index 41ce07813..30280e3da 100644 --- a/src/TesApi.Tests/BatchSchedulerTests.cs +++ b/src/TesApi.Tests/BatchSchedulerTests.cs @@ -9,13 +9,15 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager.Batch; +using Azure.ResourceManager.Batch.Models; +using Azure.Storage.Blobs; +using Azure.Storage.Blobs.Models; using Microsoft.Azure.Batch; using Microsoft.Azure.Batch.Common; -using Microsoft.Azure.Management.Batch.Models; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.VisualStudio.TestTools.UnitTesting; -using Microsoft.WindowsAzure.Storage.Blob; using Moq; using Newtonsoft.Json; using Tes.Extensions; @@ -23,6 +25,7 @@ using Tes.TaskSubmitters; using TesApi.Web; using TesApi.Web.Management; +using TesApi.Web.Management.Batch; using TesApi.Web.Management.Models.Quotas; using TesApi.Web.Storage; @@ -43,7 +46,7 @@ public async Task LocalPoolCacheAccessesNewPoolsAfterAllPoolsRemovedWithSameKey( Assert.IsTrue(batchScheduler.RemovePoolFromList(pool)); Assert.AreEqual(0, batchScheduler.GetPoolGroupKeys().Count()); - pool = (BatchPool)await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(new Pool(name: id)), System.Threading.CancellationToken.None); + pool = (BatchPool)await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(BatchPoolTests.CreatePoolData(name: id)), CancellationToken.None); Assert.IsNotNull(pool); Assert.AreEqual(1, batchScheduler.GetPoolGroupKeys().Count()); @@ -61,16 +64,16 @@ public async Task GetOrAddDoesNotAddExistingAvailablePool() var keyCount = batchScheduler.GetPoolGroupKeys().Count(); var key = batchScheduler.GetPoolGroupKeys().First(); var count = batchScheduler.GetPools().Count(); - serviceProvider.AzureProxy.Verify(mock => mock.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + serviceProvider.BatchPoolManager.Verify(mock => mock.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); - var pool = await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(new Pool(name: id)), System.Threading.CancellationToken.None); + var pool = await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(BatchPoolTests.CreatePoolData(name: id)), CancellationToken.None); await pool.ServicePoolAsync(); Assert.AreEqual(count, batchScheduler.GetPools().Count()); Assert.AreEqual(keyCount, batchScheduler.GetPoolGroupKeys().Count()); //Assert.AreSame(info, pool); Assert.AreEqual(info.PoolId, pool.PoolId); - serviceProvider.AzureProxy.Verify(mock => mock.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); + serviceProvider.BatchPoolManager.Verify(mock => mock.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny()), Times.Once); } [TestCategory("Batch Pools")] @@ -86,7 +89,7 @@ public async Task GetOrAddDoesAddWithExistingUnavailablePool() var key = batchScheduler.GetPoolGroupKeys().First(); var count = batchScheduler.GetPools().Count(); - var pool = await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(new Pool(name: id)), System.Threading.CancellationToken.None); + var pool = await batchScheduler.GetOrAddPoolAsync(key, false, (id, cancellationToken) => ValueTask.FromResult(BatchPoolTests.CreatePoolData(name: id)), CancellationToken.None); await pool.ServicePoolAsync(); Assert.AreNotEqual(count, batchScheduler.GetPools().Count()); @@ -221,15 +224,17 @@ public async Task TestIfVmSizeIsAvailable(string vmSize, bool preemptible) task.Resources.BackendParameters = new() { { "vm_size", vmSize } }; var config = GetMockConfig()(); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; using var serviceProvider = GetServiceProvider( config, - GetMockAzureProxy(AzureProxyReturnValues.Defaults), - GetMockQuotaProvider(AzureProxyReturnValues.Defaults), - GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults), + GetMockAzureProxy(azureProxyReturnValues), + GetMockBatchPoolManager(azureProxyReturnValues), + GetMockQuotaProvider(azureProxyReturnValues), + GetMockSkuInfoProvider(azureProxyReturnValues), GetMockAllowedVms(config)); var batchScheduler = serviceProvider.GetT(); - var size = await ((BatchScheduler)batchScheduler).GetVmSizeAsync(task, System.Threading.CancellationToken.None); + var size = await ((BatchScheduler)batchScheduler).GetVmSizeAsync(task, CancellationToken.None); GuardAssertsWithTesTask(task, () => Assert.AreEqual(vmSize, size.VmSize)); } @@ -241,9 +246,11 @@ public async Task TesTaskFailsWithSystemErrorWhenNoSuitableVmExists() { var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - azureProxyReturnValues.VmSizesAndPrices = new() { + azureProxyReturnValues.VmSizesAndPrices = + [ new() { VmSize = "VmSize1", LowPriority = true, VCpusAvailable = 1, MemoryInGiB = 4, ResourceDiskSizeInGiB = 20, PricePerHour = 1 }, - new() { VmSize = "VmSize2", LowPriority = true, VCpusAvailable = 2, MemoryInGiB = 8, ResourceDiskSizeInGiB = 40, PricePerHour = 2 }}; + new() { VmSize = "VmSize2", LowPriority = true, VCpusAvailable = 2, MemoryInGiB = 8, ResourceDiskSizeInGiB = 40, PricePerHour = 2 } + ]; Assert.AreEqual(TesState.SYSTEM_ERROR, await GetNewTesTaskStateAsync(new TesResources { CpuCores = 1, RamGb = 1, DiskGb = 10, Preemptible = false }, azureProxyReturnValues)); Assert.AreEqual(TesState.SYSTEM_ERROR, await GetNewTesTaskStateAsync(new TesResources { CpuCores = 4, RamGb = 1, DiskGb = 10, Preemptible = true }, azureProxyReturnValues)); @@ -260,7 +267,7 @@ public async Task TesTaskFailsWithSystemErrorWhenTotalBatchQuotaIsSetTooLow() Assert.AreEqual(TesState.SYSTEM_ERROR, await GetNewTesTaskStateAsync(new TesResources { CpuCores = 2, RamGb = 1, Preemptible = false }, azureProxyReturnValues)); Assert.AreEqual(TesState.SYSTEM_ERROR, await GetNewTesTaskStateAsync(new TesResources { CpuCores = 11, RamGb = 1, Preemptible = true }, azureProxyReturnValues)); - var dedicatedCoreQuotaPerVMFamily = new List { new("VmFamily2", 1) }; + List dedicatedCoreQuotaPerVMFamily = [CreateBatchVmFamilyCoreQuota("VmFamily2", 1)]; azureProxyReturnValues.BatchQuotas = new() { ActiveJobAndJobScheduleQuota = 1, @@ -290,16 +297,22 @@ public async Task TesTaskFailsWhenBatchNodeDiskIsFull() }); } - private async Task AddBatchTaskHandlesExceptions(TesState newState, Func, Action>)> testArranger, Action> resultValidator) + private async Task AddBatchTaskHandlesExceptions(TesState newState, Func, Action>, Action>)> testArranger, Action> resultValidator) { var logger = new Mock>(); var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - var (providerModifier, azureProxyModifier) = testArranger?.Invoke(azureProxyReturnValues) ?? (default, default); + var (providerModifier, azureProxyModifier, batchPoolManagerModifier) = testArranger?.Invoke(azureProxyReturnValues) ?? (default, default, default); var azureProxy = new Action>(mock => { GetMockAzureProxy(azureProxyReturnValues)(mock); azureProxyModifier?.Invoke(mock); }); + var batchPoolManager = new Action>(mock => + { + GetMockBatchPoolManager(azureProxyReturnValues)(mock); + batchPoolManagerModifier?.Invoke(mock); + }); + var task = GetTesTask(); task.State = TesState.QUEUED; @@ -307,6 +320,7 @@ private async Task AddBatchTaskHandlesExceptions(TesState newState, Func { @@ -326,10 +340,12 @@ public Task AddBatchTaskHandlesAzureBatchPoolCreationExceptionViaJobCreation() { return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchJobAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((_, _, _) - => throw new Microsoft.Rest.Azure.CloudException("No job for you.") { Body = new() { Code = BatchErrorCodeStrings.OperationTimedOut } })); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + azureProxy => azureProxy.Setup(b => b.CreateBatchJobAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((_, _, _) + => throw new Microsoft.Rest.Azure.CloudException("No job for you.") { Body = new() { Code = BatchErrorCodeStrings.OperationTimedOut } }), + default); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -349,10 +365,12 @@ public Task AddBatchTaskHandlesAzureBatchPoolCreationExceptionViaPoolCreation() { return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((poolInfo, isPreemptible, cancellationToken) - => throw new Microsoft.Rest.Azure.CloudException("No job for you.") { Body = new() { Code = BatchErrorCodeStrings.OperationTimedOut } })); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) + => throw new Microsoft.Rest.Azure.CloudException("No pool for you.") { Body = new() { Code = BatchErrorCodeStrings.OperationTimedOut } })); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -373,8 +391,10 @@ public Task AddBatchTaskHandlesAzureBatchQuotaMaxedOutException() var quotaVerifier = new Mock(); return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (services => services.AddSingleton(), default); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (services => services.AddSingleton(), + default, + default); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -393,8 +413,10 @@ public Task AddBatchTaskHandlesAzureBatchLowQuotaException() var quotaVerifier = new Mock(); return AddBatchTaskHandlesExceptions(TesState.SYSTEM_ERROR, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (services => services.AddSingleton(), default); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (services => services.AddSingleton(), + default, + default); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -414,10 +436,10 @@ public Task AddBatchTaskHandlesAzureBatchVirtualMachineAvailabilityException() { return AddBatchTaskHandlesExceptions(TesState.SYSTEM_ERROR, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues proxy) + (Action, Action>, Action>) Arranger(AzureProxyReturnValues proxy) { proxy.VmSizesAndPrices = Enumerable.Empty().ToList(); - return (default, default); + return (default, default, default); } void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) @@ -438,9 +460,11 @@ public Task AddBatchTaskHandlesTesException() { return AddBatchTaskHandlesExceptions(TesState.SYSTEM_ERROR, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((poolInfo, isPreemptible, cancellationToken) + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) => throw new TesException("TestFailureReason"))); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) @@ -461,13 +485,15 @@ public Task AddBatchTaskHandlesBatchClientException() { return AddBatchTaskHandlesExceptions(TesState.SYSTEM_ERROR, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.AddBatchTaskAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((_, _, _, _) - => throw typeof(BatchClientException) + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) + => throw typeof(BatchClientException) .GetConstructor(System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance, - new[] { typeof(string), typeof(Exception) }) - .Invoke(new object[] { null, null }) as Exception)); + [typeof(string), typeof(Exception)]) + .Invoke([null, null]) as Exception)); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -487,13 +513,15 @@ public Task AddBatchTaskHandlesBatchExceptionForJobQuota() { return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchJobAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((_, _, _) - => throw new BatchException( - new Mock().Object, - default, - new Microsoft.Azure.Batch.Protocol.Models.BatchErrorException() { Body = new() { Code = "ActiveJobAndScheduleQuotaReached", Message = new(value: "No job for you.") } }))); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) + => throw new BatchException( + new Mock().Object, + default, + new Microsoft.Azure.Batch.Protocol.Models.BatchErrorException() { Body = new() { Code = "ActiveJobAndScheduleQuotaReached", Message = new(value: "No job for you.") } }))); void Validator(TesTask task, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -512,13 +540,15 @@ public Task AddBatchTaskHandlesBatchExceptionForPoolQuota() { return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((poolInfo, isPreemptible, cancellationToken) - => throw new BatchException( - new Mock().Object, - default, - new Microsoft.Azure.Batch.Protocol.Models.BatchErrorException() { Body = new() { Code = "PoolQuotaReached", Message = new(value: "No pool for you.") } }))); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) + => throw new BatchException( + new Mock().Object, + default, + new Microsoft.Azure.Batch.Protocol.Models.BatchErrorException() { Body = new() { Code = "PoolQuotaReached", Message = new(value: "No pool for you.") } }))); void Validator(TesTask task, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -537,10 +567,12 @@ public Task AddBatchTaskHandlesCloudExceptionForPoolQuota() { return AddBatchTaskHandlesExceptions(TesState.QUEUED, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (default, azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Callback((poolInfo, isPreemptible, cancellationToken) - => throw new Microsoft.Rest.Azure.CloudException() { Body = new() { Code = "AutoPoolCreationFailedWithQuotaReached", Message = "No autopool for you." } })); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (default, + default, + azureProxy => azureProxy.Setup(b => b.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((poolInfo, isPreemptible, cancellationToken) + => throw new Microsoft.Rest.Azure.CloudException() { Body = new() { Code = "AutoPoolCreationFailedWithQuotaReached", Message = "No autopool for you." } })); void Validator(TesTask task, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -559,11 +591,11 @@ public Task AddBatchTaskHandlesUnknownException() { var exceptionMsg = "Successful Test"; var batchQuotaProvider = new Mock(); - batchQuotaProvider.Setup(p => p.GetVmCoreQuotaAsync(It.IsAny(), It.IsAny())).Callback((lowPriority, _1) => throw new InvalidOperationException(exceptionMsg)); + batchQuotaProvider.Setup(p => p.GetVmCoreQuotaAsync(It.IsAny(), It.IsAny())).Callback((lowPriority, _1) => throw new InvalidOperationException(exceptionMsg)); return AddBatchTaskHandlesExceptions(TesState.SYSTEM_ERROR, Arranger, Validator); - (Action, Action>) Arranger(AzureProxyReturnValues _1) - => (services => services.AddTransient(p => batchQuotaProvider.Object), default); + (Action, Action>, Action>) Arranger(AzureProxyReturnValues _1) + => (services => services.AddTransient(p => batchQuotaProvider.Object), default, default); void Validator(TesTask tesTask, IEnumerable<(LogLevel logLevel, Exception exception)> logs) { @@ -585,24 +617,26 @@ public async Task BatchJobContainsExpectedBatchPoolInformation() { var tesTask = GetTesTask(); var config = GetMockConfig()(); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; using var serviceProvider = GetServiceProvider( config, - GetMockAzureProxy(AzureProxyReturnValues.Defaults), - GetMockQuotaProvider(AzureProxyReturnValues.Defaults), - GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults), + GetMockAzureProxy(azureProxyReturnValues), + GetMockBatchPoolManager(azureProxyReturnValues), + GetMockQuotaProvider(azureProxyReturnValues), + GetMockSkuInfoProvider(azureProxyReturnValues), GetMockAllowedVms(config)); var batchScheduler = serviceProvider.GetT(); - await batchScheduler.ProcessTesTaskAsync(tesTask, System.Threading.CancellationToken.None); + await batchScheduler.ProcessTesTaskAsync(tesTask, CancellationToken.None); - var createBatchPoolAsyncInvocation = serviceProvider.AzureProxy.Invocations.FirstOrDefault(i => i.Method.Name == nameof(IAzureProxy.CreateBatchPoolAsync)); - var pool = createBatchPoolAsyncInvocation?.Arguments[0] as Pool; + var createBatchPoolAsyncInvocation = serviceProvider.BatchPoolManager.Invocations.FirstOrDefault(i => i.Method.Name == nameof(IBatchPoolManager.CreateBatchPoolAsync)); + var pool = createBatchPoolAsyncInvocation?.Arguments[0] as BatchAccountPoolData; GuardAssertsWithTesTask(tesTask, () => { - Assert.AreEqual("TES-hostname-edicated1-rpsd645merzfkqmdnj7pkqrase2ancnh-", pool.Name[0..^8]); + Assert.AreEqual("TES-hostname-edicated1-rpsd645merzfkqmdnj7pkqrase2ancnh-", tesTask.PoolId[0..^8]); Assert.AreEqual("VmSizeDedicated1", pool.VmSize); - Assert.IsTrue(((BatchScheduler)batchScheduler).TryGetPool(pool.Name, out _)); + Assert.IsTrue(((BatchScheduler)batchScheduler).TryGetPool(tesTask.PoolId, out _)); }); } @@ -616,8 +650,9 @@ public async Task BatchJobContainsExpectedManualPoolInformation() { { "workflow_execution_identity", identity } }; + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(task, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(task, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(task, () => { @@ -632,8 +667,9 @@ public async Task BatchJobContainsExpectedManualPoolInformation() public async Task NewTesTaskGetsScheduledSuccessfully() { var tesTask = GetTesTask(); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => Assert.AreEqual(TesState.INITIALIZING, tesTask.State)); } @@ -643,8 +679,9 @@ public async Task PreemptibleTesTaskGetsScheduledToLowPriorityVm() { var tesTask = GetTesTask(); tesTask.Resources.Preemptible = true; + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -659,8 +696,9 @@ public async Task NonPreemptibleTesTaskGetsScheduledToDedicatedVm() { var tesTask = GetTesTask(); tesTask.Resources.Preemptible = false; + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -675,8 +713,9 @@ public async Task PreemptibleTesTaskGetsScheduledToLowPriorityVm_PerVMFamilyEnfo { var tesTask = GetTesTask(); tesTask.Resources.Preemptible = true; + var azureProxyReturnValues = AzureProxyReturnValues.DefaultsPerVMFamilyEnforced; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.DefaultsPerVMFamilyEnforced), AzureProxyReturnValues.DefaultsPerVMFamilyEnforced); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -691,8 +730,9 @@ public async Task NonPreemptibleTesTaskGetsScheduledToDedicatedVm_PerVMFamilyEnf { var tesTask = GetTesTask(); tesTask.Resources.Preemptible = false; + var azureProxyReturnValues = AzureProxyReturnValues.DefaultsPerVMFamilyEnforced; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(AzureProxyReturnValues.DefaultsPerVMFamilyEnforced), AzureProxyReturnValues.DefaultsPerVMFamilyEnforced); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -712,7 +752,7 @@ public async Task NonPreemptibleTesTaskGetsWarningAndIsScheduledToLowPriorityVmI var azureProxyReturnValues = AzureProxyReturnValues.DefaultsPerVMFamilyEnforced; azureProxyReturnValues.VmSizesAndPrices.First(vm => vm.VmSize.Equals("VmSize3", StringComparison.OrdinalIgnoreCase)).PricePerHour = 44; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), azureProxyReturnValues); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -729,8 +769,9 @@ public async Task TesTaskGetsScheduledToLowPriorityVmIfSettingUsePreemptibleVmsO var config = GetMockConfig()() .Append(("BatchScheduling:UsePreemptibleVmsOnly", "true")); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => Assert.IsTrue(poolSpec.ScaleSettings.AutoScale.Formula.Contains("$TargetLowPriorityNodes"))); } @@ -745,8 +786,9 @@ static async Task RunTest(string allowedVmSizes, TesState expectedTaskState, str var config = GetMockConfig()() .Append(("AllowedVmSizes", allowedVmSizes)); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, GetMockAzureProxy(AzureProxyReturnValues.Defaults), AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -881,13 +923,13 @@ public async Task TaskGetsCancelled() azureProxy = mock; }); - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxySetter, azureProxyReturnValues); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxySetter, GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { Assert.AreEqual(TesState.CANCELED, tesTask.State); Assert.IsFalse(tesTask.IsCancelRequested); - azureProxy.Verify(i => i.DeleteBatchTaskAsync(tesTask.Id, It.IsAny(), It.IsAny())); + azureProxy.Verify(i => i.DeleteBatchTaskAsync(tesTask.Id, It.IsAny(), It.IsAny())); }); } @@ -917,7 +959,7 @@ public async Task SuccessfullyCompletedTaskContainsBatchNodeMetrics() azureProxyReturnValues.BatchJobAndTaskState = BatchJobAndTaskStates.TaskCompletedSuccessfully; azureProxyReturnValues.DownloadedBlobContent = metricsFileContent; - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), azureProxyReturnValues); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -953,8 +995,9 @@ public async Task SuccessfullyCompletedTaskContainsCromwellResultCode() azureProxyReturnValues.BatchJobAndTaskState = BatchJobAndTaskStates.TaskCompletedSuccessfully; azureProxyReturnValues.DownloadedBlobContent = "2"; var azureProxy = GetMockAzureProxy(azureProxyReturnValues); + var batchPoolManager = GetMockBatchPoolManager(azureProxyReturnValues); - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxy, azureProxyReturnValues); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxy, batchPoolManager, azureProxyReturnValues); GuardAssertsWithTesTask(tesTask, () => { @@ -1003,14 +1046,14 @@ public async Task CromwellWriteFilesAreDiscoveredAndAddedIfMissedWithContentScri Uri executionDirectoryUri = default; - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxySetter, azureProxyReturnValues, serviceProviderActions: serviceProvider => + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), azureProxySetter, GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues, serviceProviderActions: serviceProvider => { var storageAccessProvider = serviceProvider.GetServiceOrCreateInstance(); var commandScriptDir = new UriBuilder(commandScriptUri) { Path = Path.GetDirectoryName(commandScriptUri.AbsolutePath).Replace('\\', '/') }.Uri; executionDirectoryUri = UrlMutableSASEqualityComparer.TrimUri(storageAccessProvider.MapLocalPathToSasUrlAsync(commandScriptDir.IsFile ? commandScriptDir.AbsolutePath : commandScriptDir.AbsoluteUri, CancellationToken.None, getContainerSas: true).Result); - serviceProvider.AzureProxy.Setup(p => p.ListBlobsAsync(It.Is(executionDirectoryUri, new UrlMutableSASEqualityComparer()), It.IsAny())).Returns(Task.FromResult>(executionDirectoryBlobs)); + serviceProvider.AzureProxy.Setup(p => p.ListBlobsAsync(It.Is(executionDirectoryUri, new UrlMutableSASEqualityComparer()), It.IsAny())).Returns(Task.FromResult>(executionDirectoryBlobs)); var uri = new UriBuilder(executionDirectoryUri); uri.Path = uri.Path.TrimEnd('/') + $"/{fileName}"; @@ -1033,8 +1076,11 @@ public async Task CromwellWriteFilesAreDiscoveredAndAddedIfMissedWithContentScri Assert.AreEqual(2, filesToDownload.Length); }); - static CloudBlob CloudBlobFromTesInput(TesInput input) - => new(UriFromTesInput(input)); + static BlobItem CloudBlobFromTesInput(TesInput input) + { + BlobUriBuilder builder = new(UriFromTesInput(input)); + return BlobsModelFactory.BlobItem(name: builder.BlobName); + } static Uri UriFromTesInput(TesInput input) { @@ -1092,15 +1138,17 @@ public async Task PoolIsCreatedInSubnetWhenBatchNodesSubnetIdIsSet() .Append(("BatchNodes:SubnetId", "subnet1")); var tesTask = GetTesTask(); - var azureProxy = GetMockAzureProxy(AzureProxyReturnValues.Defaults); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; + var azureProxy = GetMockAzureProxy(azureProxyReturnValues); + var batchPoolManager = GetMockBatchPoolManager(azureProxyReturnValues); - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, azureProxy, AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, azureProxy, batchPoolManager, azureProxyReturnValues); var poolNetworkConfiguration = poolSpec.NetworkConfiguration; GuardAssertsWithTesTask(tesTask, () => { - Assert.AreEqual(Microsoft.Azure.Management.Batch.Models.IPAddressProvisioningType.BatchManaged, poolNetworkConfiguration?.PublicIPAddressConfiguration?.Provision); + Assert.AreEqual(BatchIPAddressProvisioningType.BatchManaged, poolNetworkConfiguration?.PublicIPAddressConfiguration?.Provision); Assert.AreEqual("subnet1", poolNetworkConfiguration?.SubnetId); }); } @@ -1113,15 +1161,17 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl .Append(("BatchNodes:DisablePublicIpAddress", "true")); var tesTask = GetTesTask(); - var azureProxy = GetMockAzureProxy(AzureProxyReturnValues.Defaults); + var azureProxyReturnValues = AzureProxyReturnValues.Defaults; + var azureProxy = GetMockAzureProxy(azureProxyReturnValues); + var batchPoolManager = GetMockBatchPoolManager(azureProxyReturnValues); - (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, azureProxy, AzureProxyReturnValues.Defaults); + (_, _, var poolSpec) = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, config, azureProxy, batchPoolManager, azureProxyReturnValues); var poolNetworkConfiguration = poolSpec.NetworkConfiguration; GuardAssertsWithTesTask(tesTask, () => { - Assert.AreEqual(Microsoft.Azure.Management.Batch.Models.IPAddressProvisioningType.NoPublicIPAddresses, poolNetworkConfiguration?.PublicIPAddressConfiguration?.Provision); + Assert.AreEqual(BatchIPAddressProvisioningType.NoPublicIPAddresses, poolNetworkConfiguration?.PublicIPAddressConfiguration?.Provision); Assert.AreEqual("subnet1", poolNetworkConfiguration?.SubnetId); }); } @@ -1131,16 +1181,17 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl var azureProxyReturnValues = AzureProxyReturnValues.Defaults; azureProxyReturnValues.BatchJobAndTaskState = azureBatchJobAndTaskState ?? azureProxyReturnValues.BatchJobAndTaskState; - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), azureProxyReturnValues); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); return (tesTask.Logs?.LastOrDefault()?.FailureReason, tesTask.Logs?.LastOrDefault()?.SystemLogs?.ToArray()); } - private static async Task<(string JobId, CloudTask CloudTask, Pool batchModelsPool)> ProcessTesTaskAndGetBatchJobArgumentsAsync(TesTask tesTask, IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, AzureProxyReturnValues azureProxyReturnValues, Action additionalActions = default, Action> serviceProviderActions = default) + private static async Task<(string JobId, CloudTask CloudTask, BatchAccountPoolData batchModelsPool)> ProcessTesTaskAndGetBatchJobArgumentsAsync(TesTask tesTask, IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, Action> batchPoolManager, AzureProxyReturnValues azureProxyReturnValues, Action additionalActions = default, Action> serviceProviderActions = default) { using var serviceProvider = GetServiceProvider( configuration, azureProxy, + batchPoolManager, GetMockQuotaProvider(azureProxyReturnValues), GetMockSkuInfoProvider(azureProxyReturnValues), GetMockAllowedVms(configuration), @@ -1148,14 +1199,14 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl var batchScheduler = serviceProvider.GetT(); serviceProviderActions?.Invoke(serviceProvider); - await batchScheduler.ProcessTesTaskAsync(tesTask, System.Threading.CancellationToken.None); + await batchScheduler.ProcessTesTaskAsync(tesTask, CancellationToken.None); - var createBatchPoolAsyncInvocation = serviceProvider.AzureProxy.Invocations.FirstOrDefault(i => i.Method.Name == nameof(IAzureProxy.CreateBatchPoolAsync)); + var createBatchPoolAsyncInvocation = serviceProvider.BatchPoolManager.Invocations.FirstOrDefault(i => i.Method.Name == nameof(IBatchPoolManager.CreateBatchPoolAsync)); var addBatchTaskAsyncInvocation = serviceProvider.AzureProxy.Invocations.FirstOrDefault(i => i.Method.Name == nameof(IAzureProxy.AddBatchTaskAsync)); var jobId = addBatchTaskAsyncInvocation?.Arguments[2] as string; var cloudTask = addBatchTaskAsyncInvocation?.Arguments[1] as CloudTask; - var batchPoolsModel = createBatchPoolAsyncInvocation?.Arguments[0] as Pool; + var batchPoolsModel = createBatchPoolAsyncInvocation?.Arguments[0] as BatchAccountPoolData; return (jobId, cloudTask, batchPoolsModel); } @@ -1169,14 +1220,14 @@ private static Action> GetMockAllowedVms(IEnumerabl { allowedVms = allowedVmsConfig.Split(",").ToList(); } - proxy.Setup(p => p.GetAllowedVmSizes(It.IsAny())) + proxy.Setup(p => p.GetAllowedVmSizes(It.IsAny())) .ReturnsAsync(allowedVms); }); private static Action> GetMockSkuInfoProvider(AzureProxyReturnValues azureProxyReturnValues) => new(proxy => - proxy.Setup(p => p.GetVmSizesAndPricesAsync(It.IsAny(), It.IsAny())) + proxy.Setup(p => p.GetVmSizesAndPricesAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(azureProxyReturnValues.VmSizesAndPrices)); private static Action> GetMockQuotaProvider(AzureProxyReturnValues azureProxyReturnValues) @@ -1186,14 +1237,14 @@ private static Action> GetMockQuotaProvider(AzureProxy var vmFamilyQuota = batchQuotas.DedicatedCoreQuotaPerVMFamily?.FirstOrDefault(v => string.Equals(v.Name, "VmFamily1", StringComparison.InvariantCultureIgnoreCase))?.CoreQuota ?? 0; quotaProvider.Setup(p => - p.GetQuotaForRequirementAsync(It.IsAny(), It.Is(p => p == false), It.IsAny(), It.IsAny())) + p.GetQuotaForRequirementAsync(It.IsAny(), It.Is(p => p == false), It.IsAny(), It.IsAny())) .ReturnsAsync(() => new BatchVmFamilyQuotas(batchQuotas.DedicatedCoreQuota, vmFamilyQuota, batchQuotas.PoolQuota, batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, "VmSize1")); quotaProvider.Setup(p => - p.GetQuotaForRequirementAsync(It.IsAny(), It.Is(p => p == true), It.IsAny(), It.IsAny())) + p.GetQuotaForRequirementAsync(It.IsAny(), It.Is(p => p == true), It.IsAny(), It.IsAny())) .ReturnsAsync(() => new BatchVmFamilyQuotas(batchQuotas.LowPriorityCoreQuota, vmFamilyQuota, batchQuotas.PoolQuota, @@ -1201,27 +1252,27 @@ private static Action> GetMockQuotaProvider(AzureProxy batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, "VmSize1")); quotaProvider.Setup(p => - p.GetVmCoreQuotaAsync(It.Is(l => l == true), It.IsAny())) + p.GetVmCoreQuotaAsync(It.Is(l => l == true), It.IsAny())) .ReturnsAsync(new BatchVmCoreQuota(batchQuotas.LowPriorityCoreQuota, true, batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, - batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota)).ToList(), + batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota ?? 0)).ToList(), new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota))); quotaProvider.Setup(p => - p.GetVmCoreQuotaAsync(It.Is(l => l == false), It.IsAny())) + p.GetVmCoreQuotaAsync(It.Is(l => l == false), It.IsAny())) .ReturnsAsync(new BatchVmCoreQuota(batchQuotas.DedicatedCoreQuota, false, batchQuotas.DedicatedCoreQuotaPerVMFamilyEnforced, - batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota)).ToList(), + batchQuotas.DedicatedCoreQuotaPerVMFamily?.Select(v => new BatchVmCoresPerFamily(v.Name, v.CoreQuota ?? 0)).ToList(), new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota))); }); - private static TestServices.TestServiceProvider GetServiceProvider(IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, Action> quotaProvider, Action> skuInfoProvider, Action> allowedVmSizesServiceSetup, Action additionalActions = default) - => new(wrapAzureProxy: true, configuration: configuration, azureProxy: azureProxy, batchQuotaProvider: quotaProvider, batchSkuInformationProvider: skuInfoProvider, accountResourceInformation: GetNewBatchResourceInfo(), allowedVmSizesServiceSetup: allowedVmSizesServiceSetup, additionalActions: additionalActions); + private static TestServices.TestServiceProvider GetServiceProvider(IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, Action> batchPoolManager, Action> quotaProvider, Action> skuInfoProvider, Action> allowedVmSizesServiceSetup, Action additionalActions = default) + => new(wrapAzureProxy: true, configuration: configuration, azureProxy: azureProxy, batchPoolManager: batchPoolManager, batchQuotaProvider: quotaProvider, batchSkuInformationProvider: skuInfoProvider, accountResourceInformation: GetNewBatchResourceInfo(), allowedVmSizesServiceSetup: allowedVmSizesServiceSetup, additionalActions: additionalActions); private static async Task GetNewTesTaskStateAsync(TesTask tesTask, AzureProxyReturnValues azureProxyReturnValues) { - _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), azureProxyReturnValues); + _ = await ProcessTesTaskAndGetBatchJobArgumentsAsync(tesTask, GetMockConfig()(), GetMockAzureProxy(azureProxyReturnValues), GetMockBatchPoolManager(azureProxyReturnValues), azureProxyReturnValues); return tesTask.State; } @@ -1252,6 +1303,16 @@ private static TesTask GetTesTask() return task; } + private static Action> GetMockBatchPoolManager(AzureProxyReturnValues azureProxyReturnValues) + => azureProxy => + { + azureProxy.Setup(a => a.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((BatchAccountPoolData p, bool _1, CancellationToken _2) => Task.FromResult(azureProxyReturnValues.CreateBatchPoolImpl(p))); + azureProxy.Setup(a => a.DeleteBatchPoolAsync(It.IsAny(), It.IsAny())) + .Callback((poolId, cancellationToken) => azureProxyReturnValues.DeleteBatchPoolImpl(poolId, cancellationToken)) + .Returns(Task.CompletedTask); + }; + private static Action> GetMockAzureProxy(AzureProxyReturnValues azureProxyReturnValues) => azureProxy => { @@ -1261,16 +1322,16 @@ private static Action> GetMockAzureProxy(AzureProxyReturnValue azureProxy.Setup(a => a.GetActivePoolsAsync(It.IsAny())) .Returns(AsyncEnumerable.Empty()); - azureProxy.Setup(a => a.GetBatchJobAndTaskStateAsync(It.IsAny(), It.IsAny())) + azureProxy.Setup(a => a.GetBatchJobAndTaskStateAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.BatchJobAndTaskState)); - azureProxy.Setup(a => a.GetStorageAccountInfoAsync("defaultstorageaccount", It.IsAny())) + azureProxy.Setup(a => a.GetStorageAccountInfoAsync("defaultstorageaccount", It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.StorageAccountInfos["defaultstorageaccount"])); - azureProxy.Setup(a => a.GetStorageAccountInfoAsync("storageaccount1", It.IsAny())) + azureProxy.Setup(a => a.GetStorageAccountInfoAsync("storageaccount1", It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.StorageAccountInfos["storageaccount1"])); - azureProxy.Setup(a => a.GetStorageAccountKeyAsync(It.IsAny(), It.IsAny())) + azureProxy.Setup(a => a.GetStorageAccountKeyAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.StorageAccountKey)); azureProxy.Setup(a => a.GetBatchActiveNodeCountByVmSize()) @@ -1282,16 +1343,13 @@ private static Action> GetMockAzureProxy(AzureProxyReturnValue azureProxy.Setup(a => a.GetBatchActivePoolCount()) .Returns(azureProxyReturnValues.ActivePoolCount); - azureProxy.Setup(a => a.GetBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((string id, System.Threading.CancellationToken cancellationToken, DetailLevel detailLevel) => Task.FromResult(azureProxyReturnValues.GetBatchPoolImpl(id))); + azureProxy.Setup(a => a.GetBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((string id, CancellationToken cancellationToken, DetailLevel detailLevel) => Task.FromResult(azureProxyReturnValues.GetBatchPoolImpl(id))); - azureProxy.Setup(a => a.DownloadBlobAsync(It.IsAny(), It.IsAny())) + azureProxy.Setup(a => a.DownloadBlobAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.DownloadedBlobContent)); - azureProxy.Setup(a => a.CreateBatchPoolAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns((Pool p, bool _1, System.Threading.CancellationToken _2) => Task.FromResult(azureProxyReturnValues.CreateBatchPoolImpl(p))); - - azureProxy.Setup(a => a.GetFullAllocationStateAsync(It.IsAny(), It.IsAny())) + azureProxy.Setup(a => a.GetFullAllocationStateAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(azureProxyReturnValues.AzureProxyGetFullAllocationState?.Invoke() ?? new(null, null, null, null, null, null, null))); azureProxy.Setup(a => a.ListComputeNodesAsync(It.IsAny(), It.IsAny())) @@ -1299,10 +1357,6 @@ private static Action> GetMockAzureProxy(AzureProxyReturnValue => AsyncEnumerable.Empty() .Append(BatchPoolTests.GenerateNode(poolId, "ComputeNodeDedicated1", true, true)))); - azureProxy.Setup(a => a.DeleteBatchPoolAsync(It.IsAny(), It.IsAny())) - .Callback((poolId, cancellationToken) => azureProxyReturnValues.AzureProxyDeleteBatchPoolImpl(poolId, cancellationToken)) - .Returns(Task.CompletedTask); - azureProxy.Setup(a => a.ListTasksAsync(It.IsAny(), It.IsAny())) .Returns(azureProxyReturnValues.AzureProxyListTasks); }; @@ -1351,13 +1405,14 @@ private static TestServices.TestServiceProvider GetServiceProvi accountResourceInformation: new("defaultbatchaccount", "defaultresourcegroup", "defaultsubscription", "defaultregion", "defaultendpoint"), configuration: config, azureProxy: GetMockAzureProxy(azureProxyReturn), + batchPoolManager: GetMockBatchPoolManager(azureProxyReturn), batchQuotaProvider: GetMockQuotaProvider(azureProxyReturn), batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn), allowedVmSizesServiceSetup: GetMockAllowedVms(config)); } private static async Task AddPool(BatchScheduler batchScheduler) - => (BatchPool)await batchScheduler.GetOrAddPoolAsync("key1", false, (id, cancellationToken) => ValueTask.FromResult(new(name: id, displayName: "display1", vmSize: "vmSize1")), System.Threading.CancellationToken.None); + => (BatchPool)await batchScheduler.GetOrAddPoolAsync("key1", false, (id, _1) => ValueTask.FromResult(BatchPoolTests.CreatePoolData(id, "display1", "vmSize1")), CancellationToken.None); internal static void GuardAssertsWithTesTask(TesTask tesTask, Action assertBlock) { @@ -1419,8 +1474,8 @@ private struct BatchJobAndTaskStates private class AzureProxyReturnValues { internal Func AzureProxyGetFullAllocationState { get; set; } - internal Action AzureProxyDeleteBatchPoolIfExists { get; set; } - internal Action AzureProxyDeleteBatchPool { get; set; } + internal Action AzureProxyDeleteBatchPoolIfExists { get; set; } + internal Action AzureProxyDeleteBatchPool { get; set; } internal Func> AzureProxyListTasks { get; set; } = (jobId, detail) => AsyncEnumerable.Empty(); public Dictionary StorageAccountInfos { get; set; } public List VmSizesAndPrices { get; set; } @@ -1434,7 +1489,7 @@ private class AzureProxyReturnValues public static AzureProxyReturnValues Defaults => new() { - AzureProxyGetFullAllocationState = () => new(Microsoft.Azure.Batch.Common.AllocationState.Steady, DateTime.MinValue.ToUniversalTime(), true, 0, 0, 0, 0), + AzureProxyGetFullAllocationState = () => new(AllocationState.Steady, DateTime.MinValue.ToUniversalTime(), true, 0, 0, 0, 0), AzureProxyDeleteBatchPoolIfExists = (poolId, cancellationToken) => { }, AzureProxyDeleteBatchPool = (poolId, cancellationToken) => { }, StorageAccountInfos = new() { @@ -1447,8 +1502,8 @@ private class AzureProxyReturnValues new() { VmSize = "VmSizeDedicated1", VmFamily = "VmFamily1", LowPriority = false, VCpusAvailable = 1, MemoryInGiB = 4, ResourceDiskSizeInGiB = 20, PricePerHour = 11 }, new() { VmSize = "VmSizeDedicated2", VmFamily = "VmFamily2", LowPriority = false, VCpusAvailable = 2, MemoryInGiB = 8, ResourceDiskSizeInGiB = 40, PricePerHour = 22 } }, - BatchQuotas = new() { ActiveJobAndJobScheduleQuota = 1, PoolQuota = 1, DedicatedCoreQuota = 5, LowPriorityCoreQuota = 10, DedicatedCoreQuotaPerVMFamily = new List() }, - ActiveNodeCountByVmSize = new List(), + BatchQuotas = new() { ActiveJobAndJobScheduleQuota = 1, PoolQuota = 1, DedicatedCoreQuota = 5, LowPriorityCoreQuota = 10, DedicatedCoreQuotaPerVMFamily = [] }, + ActiveNodeCountByVmSize = [], ActiveJobCount = 0, ActivePoolCount = 0, BatchJobAndTaskState = BatchJobAndTaskStates.JobNotFound, @@ -1465,7 +1520,7 @@ private static AzureProxyReturnValues DefaultsPerVMFamilyEnforcedImpl() proxy.BatchQuotas = new() { DedicatedCoreQuotaPerVMFamilyEnforced = true, - DedicatedCoreQuotaPerVMFamily = new VirtualMachineFamilyCoreQuota[] { new("VmFamily1", proxy.BatchQuotas.DedicatedCoreQuota), new("VmFamily2", 0), new("VmFamily3", 4) }, + DedicatedCoreQuotaPerVMFamily = [CreateBatchVmFamilyCoreQuota("VmFamily1", proxy.BatchQuotas.DedicatedCoreQuota), CreateBatchVmFamilyCoreQuota("VmFamily2", 0), CreateBatchVmFamilyCoreQuota("VmFamily3", 4)], DedicatedCoreQuota = proxy.BatchQuotas.DedicatedCoreQuota, ActiveJobAndJobScheduleQuota = proxy.BatchQuotas.ActiveJobAndJobScheduleQuota, LowPriorityCoreQuota = proxy.BatchQuotas.LowPriorityCoreQuota, @@ -1474,22 +1529,24 @@ private static AzureProxyReturnValues DefaultsPerVMFamilyEnforcedImpl() return proxy; } - private readonly Dictionary> poolMetadata = []; + private readonly Dictionary> poolMetadata = []; - internal void AzureProxyDeleteBatchPoolImpl(string poolId, System.Threading.CancellationToken cancellationToken) + internal void DeleteBatchPoolImpl(string poolId, CancellationToken cancellationToken) { _ = poolMetadata.Remove(poolId); AzureProxyDeleteBatchPool(poolId, cancellationToken); } - internal string CreateBatchPoolImpl(Pool pool) + internal string CreateBatchPoolImpl(BatchAccountPoolData pool) { - var poolId = pool.Name; + var poolNameItem = pool.Metadata.Single(i => string.IsNullOrEmpty(i.Name)); + pool.Metadata.Remove(poolNameItem); + var poolId = poolNameItem.Value; poolMetadata.Add(poolId, pool.Metadata?.Select(Convert).ToList()); return poolId; - static Microsoft.Azure.Batch.MetadataItem Convert(Microsoft.Azure.Management.Batch.Models.MetadataItem item) + static MetadataItem Convert(BatchAccountPoolMetadataItem item) => new(item.Name, item.Value); } @@ -1504,11 +1561,16 @@ internal CloudPool GetBatchPoolImpl(string poolId) } } + private static BatchVmFamilyCoreQuota CreateBatchVmFamilyCoreQuota(string name, int? quota) + { + return ArmBatchModelFactory.BatchVmFamilyCoreQuota(name, quota); + } + private class TestBatchQuotaVerifierQuotaMaxedOut : TestBatchQuotaVerifierBase { public TestBatchQuotaVerifierQuotaMaxedOut(IBatchQuotaProvider batchQuotaProvider) : base(batchQuotaProvider) { } - public override Task CheckBatchAccountQuotasAsync(VirtualMachineInformation _1, bool _2, System.Threading.CancellationToken cancellationToken) + public override Task CheckBatchAccountQuotasAsync(VirtualMachineInformation _1, bool _2, CancellationToken cancellationToken) => throw new AzureBatchQuotaMaxedOutException("Test AzureBatchQuotaMaxedOutException"); } @@ -1516,7 +1578,7 @@ private class TestBatchQuotaVerifierLowQuota : TestBatchQuotaVerifierBase { public TestBatchQuotaVerifierLowQuota(IBatchQuotaProvider batchQuotaProvider) : base(batchQuotaProvider) { } - public override Task CheckBatchAccountQuotasAsync(VirtualMachineInformation _1, bool _2, System.Threading.CancellationToken cancellationToken) + public override Task CheckBatchAccountQuotasAsync(VirtualMachineInformation _1, bool _2, CancellationToken cancellationToken) => throw new AzureBatchLowQuotaException("Test AzureBatchLowQuotaException"); } @@ -1527,7 +1589,7 @@ private abstract class TestBatchQuotaVerifierBase : IBatchQuotaVerifier protected TestBatchQuotaVerifierBase(IBatchQuotaProvider batchQuotaProvider) => this.batchQuotaProvider = batchQuotaProvider; - public abstract Task CheckBatchAccountQuotasAsync(VirtualMachineInformation virtualMachineInformation, bool needPoolOrJobQuotaCheck, System.Threading.CancellationToken cancellationToken); + public abstract Task CheckBatchAccountQuotasAsync(VirtualMachineInformation virtualMachineInformation, bool needPoolOrJobQuotaCheck, CancellationToken cancellationToken); public IBatchQuotaProvider GetBatchQuotaProvider() => batchQuotaProvider; diff --git a/src/TesApi.Tests/CachingWithRetriesAzureProxyTests.cs b/src/TesApi.Tests/CachingWithRetriesAzureProxyTests.cs index 051391eae..7048a15e1 100644 --- a/src/TesApi.Tests/CachingWithRetriesAzureProxyTests.cs +++ b/src/TesApi.Tests/CachingWithRetriesAzureProxyTests.cs @@ -24,7 +24,7 @@ public async Task GetStorageAccountKeyAsync_UsesCache() PrepareAzureProxy(a); a.Setup(a => a.GetStorageAccountKeyAsync(It.IsAny(), It.IsAny())).Returns(Task.FromResult(storageAccountKey)); }); - var cachingAzureProxy = serviceProvider.GetT(); + var cachingAzureProxy = (IAzureProxy)serviceProvider.GetT(); var key1 = await cachingAzureProxy.GetStorageAccountKeyAsync(storageAccountInfo, System.Threading.CancellationToken.None); var key2 = await cachingAzureProxy.GetStorageAccountKeyAsync(storageAccountInfo, System.Threading.CancellationToken.None); @@ -43,7 +43,7 @@ public async Task GetStorageAccountInfoAsync_UsesCache() PrepareAzureProxy(a); a.Setup(a => a.GetStorageAccountInfoAsync(It.IsAny(), It.IsAny())).Returns(Task.FromResult(storageAccountInfo)); }); - var cachingAzureProxy = serviceProvider.GetT(); + var cachingAzureProxy = (IAzureProxy)serviceProvider.GetT(); var info1 = await cachingAzureProxy.GetStorageAccountInfoAsync("defaultstorageaccount", System.Threading.CancellationToken.None); var info2 = await cachingAzureProxy.GetStorageAccountInfoAsync("defaultstorageaccount", System.Threading.CancellationToken.None); @@ -61,7 +61,7 @@ public async Task GetStorageAccountInfoAsync_NullInfo_DoesNotSetCache() PrepareAzureProxy(a); a.Setup(a => a.GetStorageAccountInfoAsync(It.IsAny(), It.IsAny())).Returns(Task.FromResult((StorageAccountInfo)null)); }); - var cachingAzureProxy = serviceProvider.GetT(); + var cachingAzureProxy = (IAzureProxy)serviceProvider.GetT(); var info1 = await cachingAzureProxy.GetStorageAccountInfoAsync("defaultstorageaccount", System.Threading.CancellationToken.None); var storageAccountInfo = new StorageAccountInfo { Name = "defaultstorageaccount", Id = "Id", BlobEndpoint = new("https://defaultstorageaccount/"), SubscriptionId = "SubId" }; @@ -83,7 +83,7 @@ public void GetBatchActivePoolCount_ThrowsException_RetriesThreeTimes() PrepareAzureProxy(a); a.Setup(a => a.GetBatchActivePoolCount()).Throws(); }); - var cachingAzureProxy = serviceProvider.GetT(); + var cachingAzureProxy = (IAzureProxy)serviceProvider.GetT(); Assert.ThrowsException(() => cachingAzureProxy.GetBatchActivePoolCount()); serviceProvider.AzureProxy.Verify(mock => mock.GetBatchActivePoolCount(), Times.Exactly(4)); @@ -99,7 +99,7 @@ public void GetBatchActiveJobCount_ThrowsException_RetriesThreeTimes() PrepareAzureProxy(a); a.Setup(a => a.GetBatchActiveJobCount()).Throws(); }); - var cachingAzureProxy = serviceProvider.GetT(); + var cachingAzureProxy = (IAzureProxy)serviceProvider.GetT(); Assert.ThrowsException(() => cachingAzureProxy.GetBatchActiveJobCount()); serviceProvider.AzureProxy.Verify(mock => mock.GetBatchActiveJobCount(), Times.Exactly(4)); diff --git a/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs b/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs index 5aeec2855..4f793783e 100644 --- a/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs +++ b/src/TesApi.Tests/Repository/TesTaskPostgreSqlRepositoryIntegrationTests.cs @@ -7,21 +7,18 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Azure; +using Azure.Core; +using Azure.ResourceManager; +using Azure.ResourceManager.PostgreSql.FlexibleServers; +using Azure.ResourceManager.PostgreSql.FlexibleServers.Models; using CommonUtilities; using Microsoft.AspNetCore.Mvc; -using Microsoft.Azure.Management.AppService.Fluent; -using Microsoft.Azure.Management.PostgreSQL; -using Microsoft.Azure.Management.PostgreSQL.FlexibleServers; -using Microsoft.Azure.Management.ResourceManager.Fluent; -using Microsoft.Azure.Management.ResourceManager.Fluent.Authentication; -using Microsoft.Azure.Management.ResourceManager.Fluent.Core; using Microsoft.Extensions.Options; -using Microsoft.Rest; using Microsoft.VisualStudio.TestTools.UnitTesting; using Tes.Models; using Tes.Utilities; using TesApi.Controllers; -using FlexibleServer = Microsoft.Azure.Management.PostgreSQL.FlexibleServers; namespace Tes.Repository.Tests { @@ -433,41 +430,31 @@ public static async Task CreateTestDbAsync( { const string postgreSqlVersion = "14"; - ArgumentException.ThrowIfNullOrWhiteSpace(subscriptionId); - var tokenCredentials = new TokenCredentials(new RefreshableAzureServiceTokenProvider("https://management.azure.com/")); - var azureCredentials = new AzureCredentials(tokenCredentials, null, null, AzureEnvironment.AzureGlobalCloud); - var postgresManagementClient = new FlexibleServer.PostgreSQLManagementClient(azureCredentials) { SubscriptionId = subscriptionId, LongRunningOperationRetryTimeout = 1200 }; - var azureClient = GetAzureClient(azureCredentials); - var azureSubscriptionClient = azureClient.WithSubscription(subscriptionId); - - var rgs = (await azureSubscriptionClient.ResourceGroups.ListAsync()).ToList(); + var azureSubscriptionClient = GetArmClient(subscriptionId).GetDefaultSubscription(); - if (rgs.Any(r => r.Name.Equals(resourceGroupName, StringComparison.OrdinalIgnoreCase))) + if (await azureSubscriptionClient.GetResourceGroups().GetAllAsync() + .AnyAsync(r => r.Id.Name.Equals(resourceGroupName, StringComparison.OrdinalIgnoreCase))) { return; } - await azureSubscriptionClient - .ResourceGroups - .Define(resourceGroupName) - .WithRegion(regionName) - .CreateAsync(); - - await postgresManagementClient.Servers.CreateAsync( - resourceGroupName, - postgreSqlServerName, - new( - location: regionName, - version: postgreSqlVersion, - sku: new("Standard_B2s", "Burstable"), - storage: new(128), - administratorLogin: adminLogin, - administratorLoginPassword: adminPw, - network: new(publicNetworkAccess: "Enabled"), - highAvailability: new("Disabled") - )); - - await postgresManagementClient.Databases.CreateAsync(resourceGroupName, postgreSqlServerName, postgreSqlDatabaseName, new()); + var rg = (await azureSubscriptionClient.GetResourceGroups() + .CreateOrUpdateAsync(WaitUntil.Completed, resourceGroupName, new(new(regionName)))).Value; + + var server = (await rg.GetPostgreSqlFlexibleServers().CreateOrUpdateAsync(WaitUntil.Completed, postgreSqlServerName, new(new(regionName)) + { + Version = new(postgreSqlVersion), + Sku = new("Standard_B2s", PostgreSqlFlexibleServerSkuTier.Burstable), + StorageSizeInGB = 128, + AdministratorLogin = adminLogin, + AdministratorLoginPassword = adminPw, + //Network = new() { }, + HighAvailability = new() { Mode = PostgreSqlFlexibleServerHighAvailabilityMode.Disabled } + })).Value; + + var database = (await server.GetPostgreSqlFlexibleServerDatabases().CreateOrUpdateAsync(WaitUntil.Completed, postgreSqlDatabaseName, new())).Value; + + //var postgresManagementClient = new FlexibleServer.PostgreSQLManagementClient(azureCredentials) { SubscriptionId = subscriptionId, LongRunningOperationRetryTimeout = 1200 }; var startIp = "0.0.0.0"; var endIp = "255.255.255.255"; @@ -480,27 +467,44 @@ await postgresManagementClient.Servers.CreateAsync( endIp = ip; } - await postgresManagementClient.FirewallRules.CreateOrUpdateAsync( - resourceGroupName, - postgreSqlServerName, - "AllowTestMachine", - new FlexibleServer.Models.FirewallRule { StartIpAddress = startIp, EndIpAddress = endIp }); + Assert.IsFalse((await server.GetPostgreSqlFlexibleServerFirewallRules() + .CreateOrUpdateAsync(Azure.WaitUntil.Completed, + "AllowTestMachine", + new(System.Net.IPAddress.Parse(startIp), System.Net.IPAddress.Parse(endIp)))) + .GetRawResponse().IsError); } public static async Task DeleteResourceGroupAsync(string subscriptionId, string resourceGroupName) { - ArgumentException.ThrowIfNullOrWhiteSpace(subscriptionId); - var tokenCredentials = new TokenCredentials(new RefreshableAzureServiceTokenProvider("https://management.azure.com/")); - var azureCredentials = new AzureCredentials(tokenCredentials, null, null, AzureEnvironment.AzureGlobalCloud); - var azureClient = GetAzureClient(azureCredentials); - var azureSubscriptionClient = azureClient.WithSubscription(subscriptionId); - await azureSubscriptionClient.ResourceGroups.DeleteByNameAsync(resourceGroupName, CancellationToken.None); + var azureSubscriptionClient = GetArmClient(subscriptionId).GetDefaultSubscription(); + Assert.IsFalse((await azureSubscriptionClient.GetResourceGroups().Get(resourceGroupName, CancellationToken.None).Value + .DeleteAsync(WaitUntil.Completed, cancellationToken: CancellationToken.None)) + .GetRawResponse().IsError); } - private static Microsoft.Azure.Management.Fluent.Azure.IAuthenticated GetAzureClient(AzureCredentials azureCredentials) - => Microsoft.Azure.Management.Fluent.Azure - .Configure() - .WithLogLevel(HttpLoggingDelegatingHandler.Level.Basic) - .Authenticate(azureCredentials); + private static ArmClient GetArmClient(string subscriptionId) + { + ArgumentException.ThrowIfNullOrWhiteSpace(subscriptionId); + + Azure.Identity.DefaultAzureCredentialOptions credentialOptions = new() + { + AuthorityHost = Azure.Identity.AzureAuthorityHosts.AzurePublicCloud, + ExcludeManagedIdentityCredential = true, + ExcludeWorkloadIdentityCredential = true, + }; + + TokenCredential credentials = new Azure.Identity.DefaultAzureCredential(credentialOptions); + + ArmClientOptions clientOptions = new() + { + Environment = ArmEnvironment.AzurePublicCloud, + }; + clientOptions.Diagnostics.IsLoggingEnabled = true; + clientOptions.Retry.Mode = RetryMode.Exponential; + clientOptions.Retry.Delay = TimeSpan.FromSeconds(1); + clientOptions.Retry.MaxDelay = TimeSpan.FromSeconds(30); + clientOptions.Retry.MaxRetries = 10; + return new ArmClient(credentials, subscriptionId, clientOptions); + } } } diff --git a/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs b/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs index 6bd6aa64f..9e3366442 100644 --- a/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs +++ b/src/TesApi.Tests/Runner/TaskToNodeTaskConverterTests.cs @@ -75,7 +75,7 @@ public void SetUp() x.GetBlobUrlsAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult>([])); - var azureCloudIdentityConfig = AzureCloudConfig.CreateAsync().Result.AzureEnvironmentConfig; + var azureCloudIdentityConfig = AzureCloudConfig.FromKnownCloudNameAsync().Result.AzureEnvironmentConfig; taskToNodeTaskConverter = new TaskToNodeTaskConverter(Options.Create(terraOptions), storageAccessProviderMock.Object, Options.Create(storageOptions), Options.Create(batchAccountOptions), azureCloudIdentityConfig, new NullLogger()); } diff --git a/src/TesApi.Tests/StartupTests.cs b/src/TesApi.Tests/StartupTests.cs index 5078b2ba8..9ab1c9a98 100644 --- a/src/TesApi.Tests/StartupTests.cs +++ b/src/TesApi.Tests/StartupTests.cs @@ -24,7 +24,6 @@ namespace TesApi.Tests public class StartupTests { private Startup startup; - private Mock configurationMock; private Mock hostingEnvMock; private ServiceCollection services; private TerraApiStubData terraApiStubData; @@ -51,8 +50,9 @@ public void SetUp() options.Prefix = "TES-prefix"; }); - configurationMock = new Mock(); - configurationMock.Setup(c => c.GetSection(It.IsAny())).Returns(new Mock().Object); + ConfigurationBuilder builder = new(); + builder.AddInMemoryCollection([new("AzureServicesAuthConnectionString", $"RunAs=App;AppId={System.Guid.Empty:D}")]); + hostingEnvMock = new Mock(); hostingEnvMock.Setup(e => e.EnvironmentName).Returns("Development"); @@ -62,8 +62,11 @@ public void SetUp() services.AddSingleton(hostEnv.Object); #pragma warning restore CS0618 - Startup.AzureCloudConfig = AzureCloudConfig.CreateAsync().Result; - startup = new Startup(configurationMock.Object, NullLogger.Instance, hostingEnvMock.Object); + Startup.AzureCloudConfig = AzureCloudConfig.FromKnownCloudNameAsync().Result; + var configuration = builder.Build(); + services.AddSingleton(configuration); + services.AddSingleton(configuration); + startup = new Startup(configuration, NullLogger.Instance, hostingEnvMock.Object); } private void ConfigureTerraOptions() @@ -95,6 +98,19 @@ public void ConfigureServices_TerraOptionsAreConfigured_TerraStorageProviderIsRe Assert.IsInstanceOfType(terraStorageProvider, typeof(TerraStorageAccessProvider)); } + [TestMethod] + public void ConfigureServices_TerraOptionsAreNotConfigured_DefaultStorageProviderIsResolved() + { + startup.ConfigureServices(services); + + var serviceProvider = services.BuildServiceProvider(); + + var terraStorageProvider = serviceProvider.GetService(); + + Assert.IsNotNull(terraStorageProvider); + Assert.IsInstanceOfType(terraStorageProvider, typeof(DefaultStorageAccessProvider)); + } + [TestMethod] public void ConfigureServices_TerraOptionsAreConfigured_TerraBatchPoolManagerIsResolved() { @@ -107,7 +123,22 @@ public void ConfigureServices_TerraOptionsAreConfigured_TerraBatchPoolManagerIsR var poolManager = serviceProvider.GetService(); Assert.IsNotNull(poolManager); - Assert.IsInstanceOfType(poolManager, typeof(TerraBatchPoolManager)); + Assert.IsInstanceOfType(poolManager, typeof(CachingWithRetriesBatchPoolManager)); + Assert.IsInstanceOfType(poolManager.GetType().GetField("batchPoolManager", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic).GetValue(poolManager), typeof(TerraBatchPoolManager)); + } + + [TestMethod] + public void ConfigureServices_TerraOptionsAreNotConfigured_ArmBatchPoolManagerIsResolved() + { + startup.ConfigureServices(services); + + var serviceProvider = services.BuildServiceProvider(); + + var poolManager = serviceProvider.GetService(); + + Assert.IsNotNull(poolManager); + Assert.IsInstanceOfType(poolManager, typeof(CachingWithRetriesBatchPoolManager)); + Assert.IsInstanceOfType(poolManager.GetType().GetField("batchPoolManager", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic).GetValue(poolManager), typeof(ArmBatchPoolManager)); } [TestMethod] @@ -124,5 +155,18 @@ public void ConfigureServices_TerraOptionsAreConfigured_TerraQuotaVerifierIsReso Assert.IsNotNull(quotaProvider); Assert.IsInstanceOfType(quotaProvider, typeof(TerraQuotaProvider)); } + + [TestMethod] + public void ConfigureServices_TerraOptionsAreNotConfigured_ArmBatchQuotaVerifierIsResolved() + { + startup.ConfigureServices(services); + + var serviceProvider = services.BuildServiceProvider(); + + var quotaProvider = serviceProvider.GetService(); + + Assert.IsNotNull(quotaProvider); + Assert.IsInstanceOfType(quotaProvider, typeof(ArmBatchQuotaProvider)); + } } } diff --git a/src/TesApi.Tests/TerraBatchPoolManagerTests.cs b/src/TesApi.Tests/TerraBatchPoolManagerTests.cs index 3bb81882f..23cdbd099 100644 --- a/src/TesApi.Tests/TerraBatchPoolManagerTests.cs +++ b/src/TesApi.Tests/TerraBatchPoolManagerTests.cs @@ -6,7 +6,11 @@ using System.Linq; using System.Threading.Tasks; using AutoMapper; -using Microsoft.Azure.Management.Batch.Models; +using Azure.Core; +using Azure.ResourceManager.Batch; +using Azure.ResourceManager.Batch.Models; +using Azure.ResourceManager.Models; +using CommonUtilities; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -24,19 +28,18 @@ public class TerraBatchPoolManagerTests private TerraBatchPoolManager terraBatchPoolManager; private Mock wsmApiClientMock; private Mock> terraOptionsMock; - private Mock> batchAccountOptionsMock; + private Mock poolMetadataReaderMock; private TerraApiStubData terraApiStubData; private ApiCreateBatchPoolRequest capturedApiCreateBatchPoolRequest; [TestInitialize] public void SetUp() { - terraApiStubData = new TerraApiStubData(); - wsmApiClientMock = new Mock(); - terraOptionsMock = new Mock>(); + terraApiStubData = new(); + wsmApiClientMock = new(); + terraOptionsMock = new(); terraOptionsMock.Setup(x => x.Value).Returns(terraApiStubData.GetTerraOptions()); - batchAccountOptionsMock = new Mock>(); - batchAccountOptionsMock.Setup(x => x.Value).Returns(terraApiStubData.GetBatchAccountOptions()); + poolMetadataReaderMock = new(); wsmApiClientMock.Setup(x => x.CreateBatchPool(It.IsAny(), It.IsAny(), It.IsAny())) .Callback((arg1, arg2, arg3) => capturedApiCreateBatchPoolRequest = arg2) .ReturnsAsync(terraApiStubData.GetApiCreateBatchPoolResponse()); @@ -44,7 +47,7 @@ public void SetUp() var mapperCfg = new MapperConfiguration(cfg => cfg.AddProfile(typeof(MappingProfilePoolToWsmRequest))); terraBatchPoolManager = new TerraBatchPoolManager(wsmApiClientMock.Object, mapperCfg.CreateMapper(), - terraOptionsMock.Object, batchAccountOptionsMock.Object, NullLogger.Instance); + poolMetadataReaderMock.Object, terraOptionsMock.Object, NullLogger.Instance); } @@ -62,15 +65,15 @@ public void BatchPoolToWsmRequestMappingProfileIsValid() [TestMethod] public async Task CreateBatchPoolAsync_ValidResponse() { - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, - UserAccounts = new List() { new UserAccount("name", "password") } }; + poolInfo.UserAccounts.Add(new("name", "password")); + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); var poolId = await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); @@ -83,19 +86,19 @@ public async Task CreateBatchPoolAsync_ValidResponse() [DataRow(true, 2)] public async Task CreateBatchPoolAsync_AddsResourceIdToMetadata(bool addPoolMetadata, int expectedMetadataLength) { - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, - UserAccounts = [new UserAccount("name", "password")] }; + poolInfo.UserAccounts.Add(new("name", "password")); + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); if (addPoolMetadata) { - poolInfo.Metadata = [new MetadataItem("name", "value")]; + poolInfo.Metadata.Add(new("name", "value")); } var pool = await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); @@ -108,20 +111,21 @@ public async Task CreateBatchPoolAsync_AddsResourceIdToMetadata(bool addPoolMeta [TestMethod] public async Task CreateBatchPoolAsync_MultipleCallsHaveDifferentNameAndResourceId() { - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, }; + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); var name = capturedApiCreateBatchPoolRequest.Common.Name; var resourceId = capturedApiCreateBatchPoolRequest.Common.ResourceId; + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); Assert.AreNotEqual(name, capturedApiCreateBatchPoolRequest.Common.Name); @@ -131,23 +135,24 @@ public async Task CreateBatchPoolAsync_MultipleCallsHaveDifferentNameAndResource [TestMethod] public async Task CreateBatchPoolAsync_ValidUserIdentityResourceIdProvided_UserIdentityNameIsMapped() { - var identities = new Dictionary(); - var identityName = @"bar-identity"; var identityResourceId = $@"/subscriptions/aaaaa450-5f22-4b20-9326-b5852bb89d90/resourcegroups/foo/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{identityName}"; + var identities = new Dictionary + { + { new(identityResourceId), new UserAssignedIdentity() } + }; - identities.Add(identityResourceId, new UserAssignedIdentities()); - - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, - Identity = new BatchPoolIdentity(PoolIdentityType.UserAssigned, identities) + Identity = new(ManagedServiceIdentityType.UserAssigned) }; + poolInfo.Identity.UserAssignedIdentities.AddRange(identities); + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); @@ -157,23 +162,24 @@ public async Task CreateBatchPoolAsync_ValidUserIdentityResourceIdProvided_UserI [DataTestMethod] [DataRow("/subscription/foo/bar-identity")] [DataRow("bar-identity")] - [DataRow("")] + [DataRow(" ")] public async Task CreateBatchPoolAsync_InvalidUserIdentityResourceIdProvided_ReturnsValueProvided(string identityName) { - var identities = new Dictionary + var identities = new Dictionary { - { identityName, new UserAssignedIdentities() } + { new(identityName), new UserAssignedIdentity() } }; - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, - Identity = new BatchPoolIdentity(PoolIdentityType.UserAssigned, identities) + Identity = new(ManagedServiceIdentityType.UserAssigned) }; + poolInfo.Identity.UserAssignedIdentities.AddRange(identities); + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); @@ -182,14 +188,14 @@ public async Task CreateBatchPoolAsync_InvalidUserIdentityResourceIdProvided_Ret [TestMethod] public async Task CreateBatchPoolAsync_NoUserIdentityResourceIdProvided_NoIdentitiesMapped() { - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - DeploymentConfiguration = new DeploymentConfiguration() + DeploymentConfiguration = new() { - CloudServiceConfiguration = new CloudServiceConfiguration("osfamily", "osversion"), - VirtualMachineConfiguration = new VirtualMachineConfiguration() + VmConfiguration = new BatchVmConfiguration(new BatchImageReference(), "batchNodeAgent"), }, }; + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); @@ -199,15 +205,14 @@ public async Task CreateBatchPoolAsync_NoUserIdentityResourceIdProvided_NoIdenti [TestMethod] public async Task CreateBatchPoolAsync_UserIdentityInStartTaskMapsCorrectly() { - - var poolInfo = new Pool() + BatchAccountPoolData poolInfo = new() { - - StartTask = new StartTask() + StartTask = new() { - UserIdentity = new UserIdentity("user", new AutoUserSpecification(AutoUserScope.Pool, ElevationLevel.Admin)) + UserIdentity = new() { UserName = "user", AutoUser = new() { Scope = BatchAutoUserScope.Pool, ElevationLevel = BatchUserAccountElevationLevel.Admin } }, } }; + poolInfo.Metadata.Add(new(string.Empty, terraApiStubData.PoolId)); await terraBatchPoolManager.CreateBatchPoolAsync(poolInfo, false, System.Threading.CancellationToken.None); diff --git a/src/TesApi.Tests/TesApi.Tests.csproj b/src/TesApi.Tests/TesApi.Tests.csproj index f36a4e087..2049ff5f4 100644 --- a/src/TesApi.Tests/TesApi.Tests.csproj +++ b/src/TesApi.Tests/TesApi.Tests.csproj @@ -16,14 +16,11 @@ - - + + - - - - - + + diff --git a/src/TesApi.Tests/TestServices/TestServiceProvider.cs b/src/TesApi.Tests/TestServices/TestServiceProvider.cs index db3bde620..9558fb38f 100644 --- a/src/TesApi.Tests/TestServices/TestServiceProvider.cs +++ b/src/TesApi.Tests/TestServices/TestServiceProvider.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Threading.Tasks; using CommonUtilities; -using CommonUtilities.AzureCloud; using CommonUtilities.Options; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Configuration; @@ -21,6 +20,7 @@ using Tes.Repository; using TesApi.Web; using TesApi.Web.Management; +using TesApi.Web.Management.Batch; using TesApi.Web.Management.Configuration; using TesApi.Web.Options; using TesApi.Web.Runner; @@ -38,6 +38,7 @@ internal TestServiceProvider( IEnumerable<(string Key, string Value)> configuration = default, BatchAccountResourceInformation accountResourceInformation = default, Action> azureProxy = default, + Action> batchPoolManager = default, Action>> tesTaskRepository = default, Action> storageAccessProvider = default, Action> batchSkuInformationProvider = default, @@ -94,6 +95,7 @@ internal TestServiceProvider( .AddSingleton() .AddSingleton() .AddSingleton() + .AddSingleton(GetBatchPoolManager(batchPoolManager).Object) .IfThenElse(additionalActions is null, s => { }, s => additionalActions(s)) .BuildServiceProvider(); @@ -109,6 +111,7 @@ internal TestServiceProvider( internal Mock> TesTaskRepository { get; private set; } internal Mock StorageAccessProvider { get; private set; } internal Mock AllowedVmSizesServiceProvider { get; private set; } + internal Mock BatchPoolManager { get; private set; } internal T GetT() => GetT([], []); @@ -171,6 +174,13 @@ private Mock GetAzureProxy(Action> action) return AzureProxy = proxy; } + private Mock GetBatchPoolManager(Action> action) + { + var proxy = new Mock(); + action?.Invoke(proxy); + return BatchPoolManager = proxy; + } + private Mock GetAllowedVmSizesServiceProviderProvider(Action> action) { var proxy = new Mock(); diff --git a/src/TesApi.Web/AzureBatchAccountQuotas.cs b/src/TesApi.Web/AzureBatchAccountQuotas.cs index 30efe5cfb..46ed986ce 100644 --- a/src/TesApi.Web/AzureBatchAccountQuotas.cs +++ b/src/TesApi.Web/AzureBatchAccountQuotas.cs @@ -2,7 +2,7 @@ // Licensed under the MIT License. using System.Collections.Generic; -using Microsoft.Azure.Management.Batch.Models; +using Azure.ResourceManager.Batch.Models; namespace TesApi.Web { @@ -15,6 +15,7 @@ public struct AzureBatchAccountQuotas /// Gets the active job and job schedule quota for the Batch account. /// public int ActiveJobAndJobScheduleQuota { get; set; } + /// /// Gets the dedicated core quota for the Batch account. /// @@ -23,12 +24,14 @@ public struct AzureBatchAccountQuotas /// on the subscription so this value is not returned. /// public int DedicatedCoreQuota { get; set; } + /// /// Gets a list of the dedicated core quota per Virtual Machine family for the Batch /// account. For accounts with PoolAllocationMode set to UserSubscription, quota /// is managed on the subscription so this value is not returned. /// - public IList DedicatedCoreQuotaPerVMFamily { get; set; } + public IReadOnlyList DedicatedCoreQuotaPerVMFamily { get; set; } + /// /// Gets a value indicating whether core quotas per Virtual Machine family are enforced /// for this account @@ -43,6 +46,7 @@ public struct AzureBatchAccountQuotas /// account, and the old dedicatedCoreQuota does not apply. /// public bool DedicatedCoreQuotaPerVMFamilyEnforced { get; set; } + /// /// Gets the low priority core quota for the Batch account. /// @@ -51,6 +55,7 @@ public struct AzureBatchAccountQuotas /// on the subscription so this value is not returned. /// public int LowPriorityCoreQuota { get; set; } + /// /// Gets the pool quota for the Batch account. /// diff --git a/src/TesApi.Web/AzureProxy.cs b/src/TesApi.Web/AzureProxy.cs index ea90a80fe..c8a00a71c 100644 --- a/src/TesApi.Web/AzureProxy.cs +++ b/src/TesApi.Web/AzureProxy.cs @@ -6,30 +6,28 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Azure.Identity; +using Azure.Core; +using Azure.ResourceManager; +using Azure.ResourceManager.Resources; +using Azure.ResourceManager.Storage; +using Azure.Storage.Blobs; using CommonUtilities; using CommonUtilities.AzureCloud; using Microsoft.Azure.Batch; -using Microsoft.Azure.Batch.Auth; using Microsoft.Azure.Batch.Common; -using Microsoft.Azure.Management.ResourceManager.Fluent; -using Microsoft.Azure.Management.ResourceManager.Fluent.Authentication; -using Microsoft.Azure.Services.AppAuthentication; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Rest; -using Microsoft.WindowsAzure.Storage.Blob; using Polly; using Tes.Models; using TesApi.Web.Management; -using TesApi.Web.Management.Batch; using TesApi.Web.Management.Configuration; using TesApi.Web.Storage; using static CommonUtilities.RetryHandler; -using BatchModels = Microsoft.Azure.Management.Batch.Models; +using BatchProtocol = Microsoft.Azure.Batch.Protocol; +using BlobModels = Azure.Storage.Blobs.Models; using CloudTask = Microsoft.Azure.Batch.CloudTask; using ComputeNodeState = Microsoft.Azure.Batch.Common.ComputeNodeState; -using FluentAzure = Microsoft.Azure.Management.Fluent.Azure; using JobState = Microsoft.Azure.Batch.Common.JobState; using OnAllTasksComplete = Microsoft.Azure.Batch.Common.OnAllTasksComplete; using TaskExecutionInformation = Microsoft.Azure.Batch.TaskExecutionInformation; @@ -47,40 +45,41 @@ public partial class AzureProxy : IAzureProxy private readonly AsyncRetryHandlerPolicy batchRetryPolicyWhenNodeNotReady; private readonly ILogger logger; + private readonly AzureServicesConnectionStringCredentialOptions credentialOptions; + private readonly BatchProtocol.BatchServiceClient batchServiceClient; private readonly BatchClient batchClient; private readonly string location; - //TODO: This dependency should be injected at a higher level (e.g. scheduler), but that requires significant refactoring that should be done separately. - private readonly IBatchPoolManager batchPoolManager; - private readonly AzureCloudConfig azureCloudConfig; + private readonly ArmEnvironment armEnvironment; /// /// Constructor of AzureProxy /// /// The Azure Batch Account options /// The Azure Batch Account information - /// + /// /// /// Retry builder /// The logger /// - public AzureProxy(IOptions batchAccountOptions, BatchAccountResourceInformation batchAccountInformation, IBatchPoolManager batchPoolManager, AzureCloudConfig azureCloudConfig, RetryPolicyBuilder retryHandler, ILogger logger) + public AzureProxy(IOptions batchAccountOptions, BatchAccountResourceInformation batchAccountInformation, AzureServicesConnectionStringCredentialOptions credentialOptions, AzureCloudConfig azureCloudConfig, RetryPolicyBuilder retryHandler, ILogger logger) { ArgumentNullException.ThrowIfNull(batchAccountOptions); ArgumentNullException.ThrowIfNull(batchAccountInformation); ArgumentNullException.ThrowIfNull(logger); - ArgumentNullException.ThrowIfNull(batchPoolManager); + ArgumentNullException.ThrowIfNull(credentialOptions); ArgumentNullException.ThrowIfNull(retryHandler); ArgumentNullException.ThrowIfNull(logger); ArgumentNullException.ThrowIfNull(azureCloudConfig); - this.azureCloudConfig = azureCloudConfig; - this.batchPoolManager = batchPoolManager; + credentialOptions.AuthorityHost = azureCloudConfig.AuthorityHost; + + this.armEnvironment = azureCloudConfig.ArmEnvironment.Value; + this.credentialOptions = credentialOptions; this.logger = logger; if (string.IsNullOrWhiteSpace(batchAccountOptions.Value.AccountName)) { - //TODO: check if there's a better exception for this scenario or we need to create a custom one. - throw new InvalidOperationException("The batch account name is missing from the the configuration."); + throw new ArgumentException("The batch account name is missing from the the configuration.", nameof(batchAccountOptions)); } batchRetryPolicyWhenJobNotFound = retryHandler.PolicyBuilder @@ -95,18 +94,27 @@ public AzureProxy(IOptions batchAccountOptions, BatchAccoun .SetOnRetryBehavior(onRetry: LogRetryErrorOnRetryHandler()) .AsyncBuild(); + ServiceClientCredentials serviceClientCredentials = null; + if (!string.IsNullOrWhiteSpace(batchAccountOptions.Value.AppKey)) { //If the key is provided assume we won't use ARM and the information will be provided via config - batchClient = BatchClient.Open(new BatchSharedKeyCredentials(batchAccountOptions.Value.BaseUrl, - batchAccountOptions.Value.AccountName, batchAccountOptions.Value.AppKey)); + serviceClientCredentials = new BatchProtocol.BatchSharedKeyCredential( + batchAccountOptions.Value.AccountName, batchAccountOptions.Value.AppKey); location = batchAccountOptions.Value.Region; } else { location = batchAccountInformation.Region; - batchClient = BatchClient.Open(new BatchTokenCredentials(batchAccountInformation.BaseUrl, () => GetAzureAccessTokenAsync(CancellationToken.None, azureCloudConfig.Authentication.LoginEndpointUrl, azureCloudConfig.BatchUrl + "/.default"))); + var credentials = new AzureServicesConnectionStringCredential(credentialOptions); + serviceClientCredentials = new TokenCredentials(new BatchProtocol.BatchTokenProvider(async () => + (await credentials.GetTokenAsync(new TokenRequestContext( + [azureCloudConfig.BatchUrl.TrimEnd('/') + "/.default"], + tenantId: azureCloudConfig.Authentication.Tenant), CancellationToken.None)).Token)); } + + batchServiceClient = new(serviceClientCredentials) { BatchUrl = batchAccountInformation.BaseUrl }; + batchClient = BatchClient.Open(batchServiceClient); } /// @@ -117,30 +125,12 @@ private OnRetryHandler LogRetryErrorOnRetryHandler() => new((exception, timeSpan, retryCount, correlationId, caller) => { var requestId = (exception as BatchException)?.RequestInformation?.ServiceRequestId ?? "n/a"; - var reason = (exception.InnerException as Microsoft.Azure.Batch.Protocol.Models.BatchErrorException)?.Response?.ReasonPhrase ?? "n/a"; + var reason = (exception.InnerException as BatchProtocol.Models.BatchErrorException)?.Response?.ReasonPhrase ?? "n/a"; logger?.LogError(exception, @"Retrying in {Method}: RetryCount: {RetryCount} RetryCount: {TimeSpan:c} BatchErrorCode: '{BatchErrorCode}', ApiStatusCode '{ApiStatusCode}', Reason: '{ReasonPhrase}' ServiceRequestId: '{ServiceRequestId}', CorrelationId: {CorrelationId:D}", caller, retryCount, timeSpan, (exception as BatchException)?.RequestInformation?.BatchError?.Code ?? "n/a", (exception as BatchException)?.RequestInformation?.HttpStatusCode?.ToString("G") ?? "n/a", reason, requestId, correlationId); }); - /// - public async Task GetNextBatchJobIdAsync(string tesTaskId, CancellationToken cancellationToken) - { - var jobFilter = new ODATADetailLevel - { - FilterClause = $"startswith(id,'{tesTaskId}{BatchJobAttemptSeparator}')", - SelectClause = "id" - }; - - var lastAttemptNumber = await batchClient.JobOperations.ListJobs(jobFilter) - .ToAsyncEnumerable() - .Select(j => int.Parse(j.Id.Split(BatchJobAttemptSeparator)[1])) - .OrderBy(a => a) - .LastOrDefaultAsync(cancellationToken); - - return $"{tesTaskId}{BatchJobAttemptSeparator}{lastAttemptNumber + 1}"; - } - /// public IEnumerable GetBatchActiveNodeCountByVmSize() => batchClient.PoolOperations.ListPools() @@ -231,11 +221,7 @@ public async Task GetBatchJobAndTaskStateAsync(TesTas var attemptNumber = 0; CloudTask batchTask = null; - var jobOrTaskFilter = new ODATADetailLevel - { - FilterClause = $"startswith(id,'{tesTask.Id}{BatchJobAttemptSeparator}')", - SelectClause = "*" - }; + ODATADetailLevel jobOrTaskFilter = new(filterClause: $"startswith(id,'{tesTask.Id}{BatchJobAttemptSeparator}')", selectClause: "*"); if (string.IsNullOrWhiteSpace(tesTask.PoolId)) { @@ -246,7 +232,7 @@ public async Task GetBatchJobAndTaskStateAsync(TesTas { job = await batchClient.JobOperations.GetJobAsync(tesTask.PoolId, cancellationToken: cancellationToken); } - catch (BatchException ex) when (ex.InnerException is Microsoft.Azure.Batch.Protocol.Models.BatchErrorException e && e.Response.StatusCode == System.Net.HttpStatusCode.NotFound) + catch (BatchException ex) when (ex.InnerException is BatchProtocol.Models.BatchErrorException e && e.Response.StatusCode == System.Net.HttpStatusCode.NotFound) { logger.LogError(ex, @"Failed to get job for TesTask {TesTask}", tesTask.Id); return new AzureBatchJobAndTaskState { JobState = null }; @@ -274,8 +260,7 @@ public async Task GetBatchJobAndTaskStateAsync(TesTas poolId = job.ExecutionInformation?.PoolId; - Func computeNodePredicate = - n => (n.RecentTasks?.Select(t => t.TaskId) ?? Enumerable.Empty()).Contains(batchTask?.Id); + bool ComputeNodePredicate(ComputeNode n) => (n.RecentTasks?.Select(t => t.TaskId) ?? []).Contains(batchTask?.Id); var nodeId = string.Empty; @@ -292,14 +277,14 @@ public async Task GetBatchJobAndTaskStateAsync(TesTas { pool = await batchClient.PoolOperations.GetPoolAsync(poolId, poolFilter, cancellationToken: cancellationToken); } - catch (BatchException ex) when (ex.InnerException is Microsoft.Azure.Batch.Protocol.Models.BatchErrorException e && e.Response?.StatusCode == System.Net.HttpStatusCode.NotFound) + catch (BatchException ex) when (ex.InnerException is BatchProtocol.Models.BatchErrorException e && e.Response?.StatusCode == System.Net.HttpStatusCode.NotFound) { pool = default; } if (pool is not null) { - var node = await pool.ListComputeNodes().ToAsyncEnumerable().FirstOrDefaultAsync(computeNodePredicate, cancellationToken); + var node = await pool.ListComputeNodes().ToAsyncEnumerable().FirstOrDefaultAsync(ComputeNodePredicate, cancellationToken); if (node is not null) { @@ -356,19 +341,14 @@ public async Task GetBatchJobAndTaskStateAsync(TesTas /// public async Task DeleteBatchTaskAsync(string tesTaskId, string poolId, CancellationToken cancellationToken) { - var jobFilter = new ODATADetailLevel - { - FilterClause = $"startswith(id,'{tesTaskId}{BatchJobAttemptSeparator}')", - SelectClause = "id" - }; - + ODATADetailLevel jobFilter = new(filterClause: $"startswith(id,'{tesTaskId}{BatchJobAttemptSeparator}')", selectClause: "id"); List batchTasksToDelete = default; try { batchTasksToDelete = await batchClient.JobOperations.ListTasks(poolId, jobFilter).ToAsyncEnumerable().ToListAsync(cancellationToken); } - catch (BatchException ex) when (ex.InnerException is Microsoft.Azure.Batch.Protocol.Models.BatchErrorException bee && "JobNotFound".Equals(bee.Body?.Code, StringComparison.InvariantCultureIgnoreCase)) + catch (BatchException ex) when (ex.InnerException is BatchProtocol.Models.BatchErrorException bee && "JobNotFound".Equals(bee.Body?.Code, StringComparison.InvariantCultureIgnoreCase)) { logger.LogWarning("Job not found for TES task {TesTask}", tesTaskId); return; // Task cannot exist if the job is not found. @@ -389,12 +369,7 @@ public async Task DeleteBatchTaskAsync(string tesTaskId, string poolId, Cancella /// public async Task> GetActivePoolIdsAsync(string prefix, TimeSpan minAge, CancellationToken cancellationToken = default) { - var activePoolsFilter = new ODATADetailLevel - { - FilterClause = $"state eq 'active' and startswith(id, '{prefix}') and creationTime lt DateTime'{DateTime.UtcNow.Subtract(minAge):yyyy-MM-ddTHH:mm:ssZ}'", - SelectClause = "id" - }; - + ODATADetailLevel activePoolsFilter = new(filterClause: $"state eq 'active' and startswith(id, '{prefix}') and creationTime lt DateTime'{DateTime.UtcNow.Subtract(minAge):yyyy-MM-ddTHH:mm:ssZ}'", selectClause: "id"); return (await batchClient.PoolOperations.ListPools(activePoolsFilter).ToListAsync(cancellationToken)).Select(p => p.Id); } @@ -426,46 +401,6 @@ public async Task> GetPoolIdsReferencedByJobsAsync(Cancellat public Task DeleteBatchComputeNodesAsync(string poolId, IEnumerable computeNodes, CancellationToken cancellationToken = default) => batchClient.PoolOperations.RemoveFromPoolAsync(poolId, computeNodes, deallocationOption: ComputeNodeDeallocationOption.Requeue, resizeTimeout: TimeSpan.FromMinutes(30), cancellationToken: cancellationToken); - /// - public Task DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken = default) - => batchPoolManager.DeleteBatchPoolAsync(poolId, cancellationToken: cancellationToken); - - ///// - //public async Task DeleteBatchPoolIfExistsAsync(string poolId, CancellationToken cancellationToken = default) - //{ - // try - // { - // var poolFilter = new ODATADetailLevel - // { - // FilterClause = $"startswith(id,'{poolId}') and state ne 'deleting'", - // SelectClause = "id" - // }; - - // var poolsToDelete = await batchClient.PoolOperations.ListPools(poolFilter).ToListAsync(cancellationToken); - - // foreach (var pool in poolsToDelete) - // { - // logger.LogInformation($"Pool ID: {pool.Id} Pool State: {pool?.State} deleting..."); - // await batchClient.PoolOperations.DeletePoolAsync(pool.Id, cancellationToken: cancellationToken); - // } - // } - // catch (Exception exc) - // { - // var batchErrorCode = (exc as BatchException)?.RequestInformation?.BatchError?.Code; - - // if (batchErrorCode?.Trim().Equals("PoolBeingDeleted", StringComparison.OrdinalIgnoreCase) == true) - // { - // // Do not throw if it's a deletion race condition - // // Docs: https://learn.microsoft.com/en-us/rest/api/batchservice/Pool/Delete?tabs=HTTP - - // return; - // } - - // logger.LogError(exc, $"Pool ID: {poolId} exception while attempting to delete the pool. Batch error code: {batchErrorCode}"); - // throw; - // } - //} - /// public Task GetBatchPoolAsync(string poolId, CancellationToken cancellationToken = default, DetailLevel detailLevel = default) => batchClient.PoolOperations.GetPoolAsync(poolId, detailLevel: detailLevel, cancellationToken: cancellationToken); @@ -481,15 +416,11 @@ public async Task GetFullAllocationStateAsync(stri return new(pool.AllocationState, pool.AllocationStateTransitionTime, pool.AutoScaleEnabled, pool.TargetLowPriorityComputeNodes, pool.CurrentLowPriorityComputeNodes, pool.TargetDedicatedComputeNodes, pool.CurrentDedicatedComputeNodes); } - private static async Task> GetAccessibleStorageAccountsAsync(AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) + private IAsyncEnumerable GetAccessibleStorageAccountsAsync(CancellationToken cancellationToken) { - var azureClient = await GetAzureManagementClientAsync(azureCloudConfig, cancellationToken); - return (await azureClient.Subscriptions.ListAsync(cancellationToken: cancellationToken)) - .ToAsyncEnumerable() - .Select(s => s.SubscriptionId).SelectManyAwait(async (subscriptionId, ct) => - (await azureClient.WithSubscription(subscriptionId).StorageAccounts.ListAsync(cancellationToken: cancellationToken)) - .ToAsyncEnumerable() - .Select(a => new StorageAccountInfo { Id = a.Id, Name = a.Name, SubscriptionId = subscriptionId, BlobEndpoint = new(a.EndPoints.Primary.Blob) })); + var azureClient = GetAzureManagementClient(); + return azureClient.GetSubscriptions().SelectMany(s => s.GetStorageAccountsAsync(cancellationToken)).SelectAwaitWithCancellation(async (a, ct) => (await a.GetAsync(cancellationToken: ct)).Value) + .Select(a => new StorageAccountInfo { Id = a.Id, Name = a.Data.Name, SubscriptionId = a.Id.SubscriptionId, BlobEndpoint = a.Data.PrimaryEndpoints.BlobUri }); } /// @@ -497,10 +428,11 @@ public async Task GetStorageAccountKeyAsync(StorageAccountInfo storageAc { try { - var azureClient = await GetAzureManagementClientAsync(azureCloudConfig, cancellationToken); - var storageAccount = await azureClient.WithSubscription(storageAccountInfo.SubscriptionId).StorageAccounts.GetByIdAsync(storageAccountInfo.Id, cancellationToken); + ResourceIdentifier storageAccountId = new(storageAccountInfo.Id); + var azureClient = GetAzureManagementClient().GetResourceGroupResource(ResourceGroupResource.CreateResourceIdentifier(storageAccountId.SubscriptionId, storageAccountId.ResourceGroupName)); + var storageAccount = (await azureClient.GetStorageAccountAsync(storageAccountId.Name, cancellationToken: cancellationToken)).Value; - return (await storageAccount.GetKeysAsync(cancellationToken))[0].Value; + return (await storageAccount.GetKeysAsync(cancellationToken: cancellationToken).FirstAsync(key => Azure.ResourceManager.Storage.Models.StorageAccountKeyPermission.Full.Equals(key.Permissions), cancellationToken)).Value; } catch (Exception ex) { @@ -511,83 +443,65 @@ public async Task GetStorageAccountKeyAsync(StorageAccountInfo storageAc /// public Task UploadBlobAsync(Uri blobAbsoluteUri, string content, CancellationToken cancellationToken) - => new CloudBlockBlob(blobAbsoluteUri).UploadTextAsync(content, null, null, null, null, cancellationToken); + => new BlobClient(blobAbsoluteUri).UploadAsync(BinaryData.FromString(content), cancellationToken); /// public Task UploadBlobFromFileAsync(Uri blobAbsoluteUri, string filePath, CancellationToken cancellationToken) - => new CloudBlockBlob(blobAbsoluteUri).UploadFromFileAsync(filePath, null, null, null, cancellationToken); + { + using var stream = System.IO.File.OpenRead(filePath); + return new BlobClient(blobAbsoluteUri).UploadAsync(BinaryData.FromStream(stream), cancellationToken); + } /// - public Task DownloadBlobAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) - => new CloudBlockBlob(blobAbsoluteUri).DownloadTextAsync(null, null, null, null, cancellationToken); + public async Task DownloadBlobAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) + => (await new BlobClient(blobAbsoluteUri).DownloadContentAsync(cancellationToken)).Value.Content.ToString(); /// - public Task BlobExistsAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) - => new CloudBlockBlob(blobAbsoluteUri).ExistsAsync(null, null, cancellationToken); + public async Task BlobExistsAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) + => (await new BlobClient(blobAbsoluteUri).ExistsAsync(cancellationToken)).Value; /// - public async Task GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) + public async Task GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) { - var blob = new CloudBlockBlob(blobAbsoluteUri); + var blob = new BlobClient(blobAbsoluteUri); - if (await blob.ExistsAsync(null, null, cancellationToken)) + if ((await blob.ExistsAsync(cancellationToken)).Value) { - await blob.FetchAttributesAsync(null, null, null, cancellationToken); - return blob.Properties; + return (await blob.GetPropertiesAsync(cancellationToken: cancellationToken)).Value; } return default; } /// - public async Task> ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken) + public Task> ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken) { - var blob = new CloudBlockBlob(directoryUri); - var directory = blob.Container.GetDirectoryReference(blob.Name); + BlobUriBuilder uriBuilder = new(directoryUri); + var prefix = uriBuilder.BlobName + "/"; + uriBuilder.BlobName = null; + BlobContainerClient container = new(uriBuilder.ToUri()); - BlobContinuationToken continuationToken = null; - var results = new List(); - - do - { - var response = await directory.ListBlobsSegmentedAsync(useFlatBlobListing: true, blobListingDetails: BlobListingDetails.None, maxResults: null, currentToken: continuationToken, options: null, operationContext: null, cancellationToken: cancellationToken); - continuationToken = response.ContinuationToken; - results.AddRange(response.Results.OfType()); - } - while (continuationToken is not null); - - return results; + return Task.FromResult(container.GetBlobsAsync(prefix: prefix, cancellationToken: cancellationToken).ToBlockingEnumerable(cancellationToken)); } /// public string GetArmRegion() => location; - private static async Task GetAzureAccessTokenAsync(CancellationToken cancellationToken, string authorityHost, string scope) - => (await new DefaultAzureCredential(new DefaultAzureCredentialOptions { AuthorityHost = new Uri(authorityHost) }).GetTokenAsync(new Azure.Core.TokenRequestContext([scope]), cancellationToken)).Token; - /// /// Gets an authenticated Azure Client instance /// - /// - /// A for controlling the lifetime of the asynchronous operation. /// An authenticated Azure Client instance - private static async Task GetAzureManagementClientAsync(AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) + private ArmClient GetAzureManagementClient() { - var accessToken = await GetAzureAccessTokenAsync(cancellationToken, authorityHost: azureCloudConfig.Authentication.LoginEndpointUrl, scope: azureCloudConfig.DefaultTokenScope); - var azureCredentials = new AzureCredentials(new TokenCredentials(accessToken), null, null, azureCloudConfig.AzureEnvironment); - var azureClient = FluentAzure.Authenticate(azureCredentials); - - return azureClient; + return new(new AzureServicesConnectionStringCredential(credentialOptions), + default, + new ArmClientOptions { Environment = armEnvironment }); } - /// - public async Task CreateBatchPoolAsync(BatchModels.Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken) - => await batchPoolManager.CreateBatchPoolAsync(poolSpec, isPreemptable, cancellationToken); - /// public async Task GetStorageAccountInfoAsync(string storageAccountName, CancellationToken cancellationToken) - => await (await GetAccessibleStorageAccountsAsync(azureCloudConfig, cancellationToken)) + => await GetAccessibleStorageAccountsAsync(cancellationToken) .FirstOrDefaultAsync(storageAccount => storageAccount.Name.Equals(storageAccountName, StringComparison.OrdinalIgnoreCase), cancellationToken); /// diff --git a/src/TesApi.Web/BatchPool.cs b/src/TesApi.Web/BatchPool.cs index 49b558d27..9747188e3 100644 --- a/src/TesApi.Web/BatchPool.cs +++ b/src/TesApi.Web/BatchPool.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using Azure; +using Azure.ResourceManager.Batch; using CommonUtilities; using Microsoft.Azure.Batch; using Microsoft.Azure.Batch.Common; @@ -35,6 +36,7 @@ public sealed partial class BatchPool private readonly ILogger _logger; private readonly IAzureProxy _azureProxy; + private readonly Management.Batch.IBatchPoolManager _batchPoolManager; private readonly IStorageAccessProvider _storageAccessProvider; /// @@ -43,18 +45,20 @@ public sealed partial class BatchPool /// /// /// + /// /// /// /// - public BatchPool(IBatchScheduler batchScheduler, IOptions batchSchedulingOptions, IAzureProxy azureProxy, IStorageAccessProvider storageAccessProvider, ILogger logger) + public BatchPool(IBatchScheduler batchScheduler, IOptions batchSchedulingOptions, IAzureProxy azureProxy, Management.Batch.IBatchPoolManager batchPoolManager, IStorageAccessProvider storageAccessProvider, ILogger logger) { var rotationDays = batchSchedulingOptions.Value.PoolRotationForcedDays; if (rotationDays == 0) { rotationDays = Options.BatchSchedulingOptions.DefaultPoolRotationForcedDays; } _forcePoolRotationAge = TimeSpan.FromDays(rotationDays); - this._azureProxy = azureProxy; - this._storageAccessProvider = storageAccessProvider; - this._logger = logger; + _azureProxy = azureProxy; + _batchPoolManager = batchPoolManager; + _storageAccessProvider = storageAccessProvider; + _logger = logger; _batchPools = batchScheduler as BatchScheduler ?? throw new ArgumentException("batchScheduler must be of type BatchScheduler", nameof(batchScheduler)); } @@ -662,16 +666,18 @@ public async ValueTask GetAllocationStateTransitionTime(CancellationTo => (await _azureProxy.GetBatchPoolAsync(PoolId, cancellationToken, new ODATADetailLevel { SelectClause = "allocationStateTransitionTime" })).AllocationStateTransitionTime ?? DateTime.UtcNow; /// - public async ValueTask CreatePoolAndJobAsync(Microsoft.Azure.Management.Batch.Models.Pool poolModel, bool isPreemptible, CancellationToken cancellationToken) + public async ValueTask CreatePoolAndJobAsync(BatchAccountPoolData poolModel, bool isPreemptible, CancellationToken cancellationToken) { + var jobId = poolModel.Metadata.Single(i => string.IsNullOrEmpty(i.Name)).Value; + try { CloudPool pool = default; await Task.WhenAll( - _azureProxy.CreateBatchJobAsync(poolModel.Name, poolModel.Name, cancellationToken), + _azureProxy.CreateBatchJobAsync(jobId, jobId, cancellationToken), Task.Run(async () => { - var poolId = await _azureProxy.CreateBatchPoolAsync(poolModel, isPreemptible, cancellationToken); + var poolId = await _batchPoolManager.CreateBatchPoolAsync(poolModel, isPreemptible, cancellationToken); pool = await _azureProxy.GetBatchPoolAsync(poolId, cancellationToken, new ODATADetailLevel { SelectClause = CloudPoolSelectClause }); }, cancellationToken)); @@ -698,8 +704,9 @@ Exception HandleException(Exception ex) { // When the batch management API creating the pool times out, it may or may not have created the pool. // Add an inactive record to delete it if it did get created and try again later. That record will be removed later whether or not the pool was created. - PoolId ??= poolModel.Name; + PoolId ??= jobId; _ = _batchPools.AddPool(this); + return ex switch { OperationCanceledException => ex.InnerException is null ? ex : new AzureBatchPoolCreationException(ex.Message, true, ex), diff --git a/src/TesApi.Web/BatchScheduler.BatchPools.cs b/src/TesApi.Web/BatchScheduler.BatchPools.cs index 9f2283d87..ed6cf1765 100644 --- a/src/TesApi.Web/BatchScheduler.BatchPools.cs +++ b/src/TesApi.Web/BatchScheduler.BatchPools.cs @@ -5,18 +5,20 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; +using System.Net.Http; using System.Security.Cryptography; using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager.Batch; using CommonUtilities; using Microsoft.Azure.Batch; using Microsoft.Azure.Batch.Common; using Microsoft.Extensions.Logging; using Tes.Models; +using TesApi.Web.Management.Batch; using static TesApi.Web.BatchScheduler.BatchPools; -using BatchModels = Microsoft.Azure.Management.Batch.Models; namespace TesApi.Web { @@ -25,7 +27,7 @@ public partial class BatchScheduler [GeneratedRegex("^[a-zA-Z0-9_-]+$")] private static partial Regex PoolNameRegex(); - internal delegate ValueTask ModelPoolFactory(string poolId, CancellationToken cancellationToken); + internal delegate ValueTask ModelPoolFactory(string poolId, CancellationToken cancellationToken); private (string PoolKey, string DisplayName) GetPoolKey(Tes.Models.TesTask tesTask, Tes.Models.VirtualMachineInformation virtualMachineInformation, List identities, CancellationToken cancellationToken) { @@ -146,7 +148,6 @@ internal async Task GetOrAddPoolAsync(string key, bool isPreemptable RandomNumberGenerator.Fill(uniquifier); var poolId = $"{key}-{uniquifier.ConvertToBase32().TrimEnd('=').ToLowerInvariant()}"; // embedded '-' is required by GetKeyFromPoolId() var modelPool = await modelPoolFactory(poolId, cancellationToken); - modelPool.Metadata ??= []; modelPool.Metadata.Add(new(PoolMetadata, new IBatchScheduler.PoolMetadata(this.batchPrefix, !isPreemptable, this.runnerMD5).ToString())); var batchPool = _batchPoolFactory.CreateNew(); await batchPool.CreatePoolAndJobAsync(modelPool, isPreemptable, cancellationToken); @@ -199,13 +200,15 @@ public Task DeletePoolAsync(IBatchPool pool, CancellationToken cancellationToken logger.LogDebug(@"Deleting pool and job {PoolId}", pool.PoolId); return Task.WhenAll( - AllowIfNotFound(azureProxy.DeleteBatchPoolAsync(pool.PoolId, cancellationToken)), + AllowIfNotFound(batchPoolManager.DeleteBatchPoolAsync(pool.PoolId, cancellationToken)), AllowIfNotFound(azureProxy.DeleteBatchJobAsync(pool.PoolId, cancellationToken))); static async Task AllowIfNotFound(Task task) { try { await task; } catch (BatchException ex) when (ex.InnerException is Microsoft.Azure.Batch.Protocol.Models.BatchErrorException e && e.Response.StatusCode == System.Net.HttpStatusCode.NotFound) { } + catch (HttpRequestException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound) { } + //catch (InvalidOperationException) { } // Terra providers may also throw this catch { throw; } } } diff --git a/src/TesApi.Web/BatchScheduler.cs b/src/TesApi.Web/BatchScheduler.cs index 316e92090..607daf28e 100644 --- a/src/TesApi.Web/BatchScheduler.cs +++ b/src/TesApi.Web/BatchScheduler.cs @@ -9,6 +9,9 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager.Batch; +using Azure.Storage.Blobs; +using CommonUtilities; using Microsoft.Azure.Batch; using Microsoft.Azure.Batch.Common; using Microsoft.Extensions.Logging; @@ -18,10 +21,11 @@ using Tes.Models; using TesApi.Web.Extensions; using TesApi.Web.Management; +using TesApi.Web.Management.Batch; using TesApi.Web.Management.Models.Quotas; using TesApi.Web.Runner; using TesApi.Web.Storage; -using BatchModels = Microsoft.Azure.Management.Batch.Models; +using BatchModels = Azure.ResourceManager.Batch.Models; using TesException = Tes.Models.TesException; using TesFileType = Tes.Models.TesFileType; using TesInput = Tes.Models.TesInput; @@ -62,6 +66,7 @@ public partial class BatchScheduler : IBatchScheduler private const string NodeTaskRunnerMD5HashFilename = NodeTaskRunnerFilename + ".md5"; private readonly ILogger logger; private readonly IAzureProxy azureProxy; + private readonly IBatchPoolManager batchPoolManager; private readonly IStorageAccessProvider storageAccessProvider; private readonly IBatchQuotaVerifier quotaVerifier; private readonly IBatchSkuInformationProvider skuInformationProvider; @@ -97,6 +102,7 @@ public partial class BatchScheduler : IBatchScheduler /// Configuration of /// Configuration of /// Azure proxy + /// Azure batch pool management proxy /// Storage access provider /// Quota verifier > /// Sku information provider @@ -112,6 +118,7 @@ public BatchScheduler( IOptions batchNodesOptions, IOptions batchSchedulingOptions, IAzureProxy azureProxy, + IBatchPoolManager batchPoolManager, IStorageAccessProvider storageAccessProvider, IBatchQuotaVerifier quotaVerifier, IBatchSkuInformationProvider skuInformationProvider, @@ -129,6 +136,7 @@ public BatchScheduler( this.logger = logger; this.azureProxy = azureProxy; + this.batchPoolManager = batchPoolManager; this.storageAccessProvider = storageAccessProvider; this.quotaVerifier = quotaVerifier; this.skuInformationProvider = skuInformationProvider; @@ -284,7 +292,7 @@ private async Task AddProcessLogsIfAvailable(TesTask tesTask, CancellationToken // Get any logs the task runner left. Look for the latest set in this order: upload, exec, download foreach (var prefix in new[] { "upload_std", "exec_std", "download_std" }) { - var logs = FilterByPrefix(prefix, await azureProxy.ListBlobsAsync(directoryUri, cancellationToken)); + var logs = FilterByPrefix(directoryUri, prefix, await azureProxy.ListBlobsAsync(directoryUri, cancellationToken)); if (logs.Any()) { @@ -312,8 +320,8 @@ private async Task AddProcessLogsIfAvailable(TesTask tesTask, CancellationToken } #pragma warning disable IDE0305 // Simplify collection initialization - static IList<(Uri BlobUri, string[] BlobNameParts)> FilterByPrefix(string blobNameStartsWith, IEnumerable blobs) - => blobs.Select(blob => (BlobUri: new Azure.Storage.Blobs.BlobUriBuilder(blob.Uri) { Sas = null }.ToUri(), BlobName: blob.Name.Split('/').Last())) + static IList<(Uri BlobUri, string[] BlobNameParts)> FilterByPrefix(Uri directoryUri, string blobNameStartsWith, IEnumerable blobs) + => blobs.Select(blob => (BlobUri: new BlobUriBuilder(directoryUri) { Sas = null, BlobName = blob.Name }.ToUri(), BlobName: blob.Name.Split('/').Last())) .Where(blob => blob.BlobName.EndsWith(".txt") && blob.BlobName.StartsWith(blobNameStartsWith)) .Select(blob => (blob.BlobUri, BlobNameParts: blob.BlobName.Split('_', 4))) .OrderBy(blob => string.Join('_', blob.BlobNameParts.Take(3))) @@ -401,7 +409,7 @@ public async Task UploadTaskRunnerIfNeeded(CancellationToken cancellationToken) { var blobUri = await storageAccessProvider.GetInternalTesBlobUrlAsync(NodeTaskRunnerFilename, cancellationToken); var blobProperties = await azureProxy.GetBlobPropertiesAsync(blobUri, cancellationToken); - if (!runnerMD5.Equals(blobProperties?.ContentMD5, StringComparison.OrdinalIgnoreCase)) + if (!(await File.ReadAllTextAsync(Path.Combine(AppContext.BaseDirectory, $"scripts/{NodeTaskRunnerMD5HashFilename}"), cancellationToken)).Trim().Equals(Convert.ToBase64String(blobProperties?.ContentHash ?? []), StringComparison.OrdinalIgnoreCase)) { await azureProxy.UploadBlobFromFileAsync(blobUri, $"scripts/{NodeTaskRunnerFilename}", cancellationToken); } @@ -972,7 +980,7 @@ private async Task> GetExistingBlobsInCromwellStorageLocationAsTe var commandScriptPathParts = commandScript.Path.Split('/').ToList(); additionalInputFiles = await blobsInExecutionDirectory .Select(b => (Path: $"/{metadata.CromwellExecutionDir.TrimStart('/')}/{b.Name.Split('/').Last()}", - b.Uri)) + Uri: new BlobUriBuilder(executionDirectoryUri) { BlobName = b.Name }.ToUri())) .ToAsyncEnumerable() .SelectAwait(async b => new TesInput { @@ -1001,12 +1009,12 @@ enum StartScriptVmFamilies /// Constructs a universal Azure Start Task instance /// /// Pool Id - /// A describing the OS of the pool's nodes. + /// A describing the OS of the pool's nodes. /// /// A for controlling the lifetime of the asynchronous operation. /// /// This method also mitigates errors associated with docker daemons that are not configured to place their filesystem assets on the data drive. - private async Task GetStartTaskAsync(string poolId, BatchModels.VirtualMachineConfiguration machineConfiguration, string vmFamily, CancellationToken cancellationToken) + private async Task GetStartTaskAsync(string poolId, BatchModels.BatchVmConfiguration machineConfiguration, string vmFamily, CancellationToken cancellationToken) { ArgumentException.ThrowIfNullOrWhiteSpace(poolId); ArgumentNullException.ThrowIfNull(machineConfiguration); @@ -1075,7 +1083,7 @@ var s when s.StartsWith("batch.node.centos ", StringComparison.OrdinalIgnoreCase return new() { CommandLine = $"/bin/sh -c \"{CreateWgetDownloadCommand(await UploadScriptAsync(StartTaskScriptFilename, cmd), $"{BatchNodeTaskWorkingDirEnvVar}/{StartTaskScriptFilename}", true)} && {BatchNodeTaskWorkingDirEnvVar}/{StartTaskScriptFilename}\"", - UserIdentity = new BatchModels.UserIdentity(autoUser: new(elevationLevel: BatchModels.ElevationLevel.Admin, scope: BatchModels.AutoUserScope.Pool)), + UserIdentity = new() { AutoUser = new() { ElevationLevel = BatchModels.BatchUserAccountElevationLevel.Admin, Scope = BatchModels.BatchAutoUserScope.Pool } }, MaxTaskRetryCount = 1, WaitForSuccess = true }; @@ -1103,8 +1111,17 @@ async ValueTask ReadScript(string name) /// /// /// - private static BatchModels.BatchPoolIdentity GetBatchPoolIdentity(string[] identities) - => identities is null || !identities.Any() ? null : new(BatchModels.PoolIdentityType.UserAssigned, identities.ToDictionary(identity => identity, _ => new BatchModels.UserAssignedIdentities())); + private static Azure.ResourceManager.Models.ManagedServiceIdentity GetBatchPoolIdentity(string[] identities) + { + if (identities is null || identities.Length == 0) + { + return null; + } + + Azure.ResourceManager.Models.ManagedServiceIdentity result = new(Azure.ResourceManager.Models.ManagedServiceIdentityType.UserAssigned); + result.UserAssignedIdentities.AddRange(identities.ToDictionary(identity => new Azure.Core.ResourceIdentifier(identity), _ => new Azure.ResourceManager.Models.UserAssignedIdentity())); + return result; + } /// /// Generate the PoolSpecification for the needed pool. @@ -1118,47 +1135,52 @@ private static BatchModels.BatchPoolIdentity GetBatchPoolIdentity(string[] ident /// /// VM supports encryption at host. /// A for controlling the lifetime of the asynchronous operation. - /// A . + /// A . /// /// Devs: Any changes to any properties set in this method will require corresponding changes to all classes implementing along with possibly any systems they call, with the possible exception of . /// - private async ValueTask GetPoolSpecification(string name, string displayName, BatchModels.BatchPoolIdentity poolIdentity, string vmSize, string vmFamily, bool preemptable, BatchNodeInfo nodeInfo, bool? encryptionAtHostSupported, CancellationToken cancellationToken) + private async ValueTask GetPoolSpecification(string name, string displayName, Azure.ResourceManager.Models.ManagedServiceIdentity poolIdentity, string vmSize, string vmFamily, bool preemptable, BatchNodeInfo nodeInfo, bool? encryptionAtHostSupported, CancellationToken cancellationToken) { // TODO: (perpetually) add new properties we set in the future on and/or its contained objects, if possible. When not, update CreateAutoPoolModePoolInformation(). ValidateString(name, nameof(name), 64); ValidateString(displayName, nameof(displayName), 1024); - var vmConfig = new BatchModels.VirtualMachineConfiguration( - imageReference: new BatchModels.ImageReference( - publisher: nodeInfo.BatchImagePublisher, - offer: nodeInfo.BatchImageOffer, - sku: nodeInfo.BatchImageSku, - version: nodeInfo.BatchImageVersion), + var vmConfig = new BatchModels.BatchVmConfiguration( + imageReference: new() + { + Publisher = nodeInfo.BatchImagePublisher, + Offer = nodeInfo.BatchImageOffer, + Sku = nodeInfo.BatchImageSku, + Version = nodeInfo.BatchImageVersion, + }, nodeAgentSkuId: nodeInfo.BatchNodeAgentSkuId); if (encryptionAtHostSupported ?? false) { - vmConfig.DiskEncryptionConfiguration = new( - targets: new List { BatchModels.DiskEncryptionTarget.OsDisk, BatchModels.DiskEncryptionTarget.TemporaryDisk } - ); + vmConfig.DiskEncryptionTargets.AddRange([BatchModels.BatchDiskEncryptionTarget.OSDisk, BatchModels.BatchDiskEncryptionTarget.TemporaryDisk]); } - var poolSpecification = new BatchModels.Pool(name: name, displayName: displayName, identity: poolIdentity, vmSize: vmSize) + BatchAccountPoolData poolSpecification = new() { - ScaleSettings = new(autoScale: new(BatchPool.AutoPoolFormula(preemptable, 1), BatchPool.AutoScaleEvaluationInterval)), - DeploymentConfiguration = new(virtualMachineConfiguration: vmConfig), + DisplayName = displayName, + Identity = poolIdentity, + VmSize = vmSize, + ScaleSettings = new() { AutoScale = new(BatchPool.AutoPoolFormula(preemptable, 1)) { EvaluationInterval = BatchPool.AutoScaleEvaluationInterval } }, + DeploymentConfiguration = new() { VmConfiguration = vmConfig }, //ApplicationPackages = , StartTask = await GetStartTaskAsync(name, vmConfig, vmFamily, cancellationToken), TargetNodeCommunicationMode = BatchModels.NodeCommunicationMode.Simplified, }; + poolSpecification.Metadata.Add(new(string.Empty, name)); + if (!string.IsNullOrEmpty(batchNodesSubnetId)) { poolSpecification.NetworkConfiguration = new() { - PublicIPAddressConfiguration = new(provision: disableBatchNodesPublicIpAddress ? BatchModels.IPAddressProvisioningType.NoPublicIPAddresses : BatchModels.IPAddressProvisioningType.BatchManaged), - SubnetId = batchNodesSubnetId + PublicIPAddressConfiguration = new() { Provision = disableBatchNodesPublicIpAddress ? BatchModels.BatchIPAddressProvisioningType.NoPublicIPAddresses : BatchModels.BatchIPAddressProvisioningType.BatchManaged }, + SubnetId = new(batchNodesSubnetId), }; } diff --git a/src/TesApi.Web/CachingWithRetriesAzureProxy.cs b/src/TesApi.Web/CachingWithRetriesAzureProxy.cs index 698c69250..7b3e04006 100644 --- a/src/TesApi.Web/CachingWithRetriesAzureProxy.cs +++ b/src/TesApi.Web/CachingWithRetriesAzureProxy.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; using CommonUtilities; @@ -13,71 +12,34 @@ using Microsoft.Extensions.Logging; using Tes.ApiClients; using TesApi.Web.Storage; -using static Tes.ApiClients.CachingRetryHandler; -using BatchModels = Microsoft.Azure.Management.Batch.Models; +using BlobModels = Azure.Storage.Blobs.Models; namespace TesApi.Web { /// /// Implements caching and retries for . /// - public class CachingWithRetriesAzureProxy : IAzureProxy + public class CachingWithRetriesAzureProxy : CachingWithRetriesBase, IAzureProxy { - private readonly ILogger logger; private readonly IAzureProxy azureProxy; - private readonly CachingRetryHandlerPolicy cachingRetry; - private readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetry; - private readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetryExceptWhenExists; - private readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetryExceptWhenNotFound; /// - /// Contructor to create a cache of + /// Constructor to create a cache of /// /// /// /// public CachingWithRetriesAzureProxy(IAzureProxy azureProxy, CachingRetryPolicyBuilder cachingRetryHandler, ILogger logger) + : base(cachingRetryHandler, logger) { ArgumentNullException.ThrowIfNull(azureProxy); - ArgumentNullException.ThrowIfNull(cachingRetryHandler); this.azureProxy = azureProxy; - this.logger = logger; - - var sleepDuration = new Func((attempt, exception) => (exception as BatchException)?.RequestInformation?.RetryAfter); - - this.cachingRetry = cachingRetryHandler.PolicyBuilder.OpinionatedRetryPolicy() - .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().SyncBuild(); - - this.cachingAsyncRetry = cachingRetryHandler.PolicyBuilder.OpinionatedRetryPolicy() - .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); - - this.cachingAsyncRetryExceptWhenExists = cachingRetryHandler.PolicyBuilder - .OpinionatedRetryPolicy(Polly.Policy.Handle(ex => !CreationErrorFoundCodes.Contains(ex.RequestInformation?.BatchError?.Code, StringComparer.OrdinalIgnoreCase))) - .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); - - this.cachingAsyncRetryExceptWhenNotFound = cachingRetryHandler.PolicyBuilder - .OpinionatedRetryPolicy(Polly.Policy.Handle(ex => !DeletionErrorFoundCodes.Contains(ex.RequestInformation?.BatchError?.Code, StringComparer.OrdinalIgnoreCase))) - .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); } - private static readonly string[] CreationErrorFoundCodes = new[] - { - BatchErrorCodeStrings.TaskExists, - BatchErrorCodeStrings.PoolExists, - BatchErrorCodeStrings.JobExists - }; - - private static readonly string[] DeletionErrorFoundCodes = new[] - { - BatchErrorCodeStrings.TaskNotFound, - BatchErrorCodeStrings.PoolNotFound, - BatchErrorCodeStrings.JobNotFound - }; - /// - public async Task CreateBatchJobAsync(string jobId, string poolId, CancellationToken cancellationToken) + async Task IAzureProxy.CreateBatchJobAsync(string jobId, string poolId, CancellationToken cancellationToken) { try { @@ -88,7 +50,7 @@ public async Task CreateBatchJobAsync(string jobId, string poolId, CancellationT } /// - public async Task AddBatchTaskAsync(string tesTaskId, CloudTask cloudTask, string jobId, CancellationToken cancellationToken) + async Task IAzureProxy.AddBatchTaskAsync(string tesTaskId, CloudTask cloudTask, string jobId, CancellationToken cancellationToken) { try { @@ -99,7 +61,7 @@ public async Task AddBatchTaskAsync(string tesTaskId, CloudTask cloudTask, strin } /// - public async Task DeleteBatchJobAsync(string jobId, CancellationToken cancellationToken) + async Task IAzureProxy.DeleteBatchJobAsync(string jobId, CancellationToken cancellationToken) { try { @@ -110,7 +72,7 @@ public async Task DeleteBatchJobAsync(string jobId, CancellationToken cancellati } /// - public async Task DeleteBatchTaskAsync(string taskId, string poolId, CancellationToken cancellationToken) + async Task IAzureProxy.DeleteBatchTaskAsync(string taskId, string poolId, CancellationToken cancellationToken) { try { @@ -121,64 +83,53 @@ public async Task DeleteBatchTaskAsync(string taskId, string poolId, Cancellatio } /// - public async Task DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken) - { - try - { - await cachingAsyncRetryExceptWhenNotFound.ExecuteWithRetryAsync(ct => azureProxy.DeleteBatchPoolAsync(poolId, ct), cancellationToken); - } - catch (BatchException exc) when (BatchErrorCodeStrings.TaskNotFound.Equals(exc.RequestInformation?.BatchError?.Code, StringComparison.OrdinalIgnoreCase)) - { } - } - - /// - public Task GetBatchPoolAsync(string poolId, CancellationToken cancellationToken, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchPoolAsync(poolId, ct, detailLevel), cancellationToken); + Task IAzureProxy.GetBatchPoolAsync(string poolId, CancellationToken cancellationToken, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchPoolAsync(poolId, ct, detailLevel), cancellationToken); /// - public Task GetBatchJobAsync(string jobId, CancellationToken cancellationToken, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchJobAsync(jobId, ct, detailLevel), cancellationToken); + Task IAzureProxy.GetBatchJobAsync(string jobId, CancellationToken cancellationToken, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchJobAsync(jobId, ct, detailLevel), cancellationToken); /// - public async Task DeleteBatchComputeNodesAsync(string poolId, IEnumerable computeNodes, CancellationToken cancellationToken) + async Task IAzureProxy.DeleteBatchComputeNodesAsync(string poolId, IEnumerable computeNodes, CancellationToken cancellationToken) { cachingAsyncRetry.AppCache.Remove($"{nameof(CachingWithRetriesAzureProxy)}:{poolId}"); await cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.DeleteBatchComputeNodesAsync(poolId, computeNodes, ct), cancellationToken); } /// - public Task DownloadBlobAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.DownloadBlobAsync(blobAbsoluteUri, ct), cancellationToken); + Task IAzureProxy.DownloadBlobAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.DownloadBlobAsync(blobAbsoluteUri, ct), cancellationToken); /// - public Task BlobExistsAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.BlobExistsAsync(blobAbsoluteUri, ct), cancellationToken); + Task IAzureProxy.BlobExistsAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.BlobExistsAsync(blobAbsoluteUri, ct), cancellationToken); /// - public Task> GetActivePoolIdsAsync(string prefix, TimeSpan minAge, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetActivePoolIdsAsync(prefix, minAge, ct), cancellationToken); + Task> IAzureProxy.GetActivePoolIdsAsync(string prefix, TimeSpan minAge, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetActivePoolIdsAsync(prefix, minAge, ct), cancellationToken); /// - public IAsyncEnumerable GetActivePoolsAsync(string hostName) => cachingRetry.ExecuteWithRetry(() => azureProxy.GetActivePoolsAsync(hostName)); + IAsyncEnumerable IAzureProxy.GetActivePoolsAsync(string hostName) => cachingRetry.ExecuteWithRetry(() => azureProxy.GetActivePoolsAsync(hostName)); /// - public int GetBatchActiveJobCount() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActiveJobCount); + int IAzureProxy.GetBatchActiveJobCount() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActiveJobCount); /// - public IEnumerable GetBatchActiveNodeCountByVmSize() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActiveNodeCountByVmSize); + IEnumerable IAzureProxy.GetBatchActiveNodeCountByVmSize() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActiveNodeCountByVmSize); /// - public int GetBatchActivePoolCount() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActivePoolCount); + int IAzureProxy.GetBatchActivePoolCount() => cachingRetry.ExecuteWithRetry(azureProxy.GetBatchActivePoolCount); /// - public Task GetBatchJobAndTaskStateAsync(Tes.Models.TesTask tesTask, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchJobAndTaskStateAsync(tesTask, ct), cancellationToken); + Task IAzureProxy.GetBatchJobAndTaskStateAsync(Tes.Models.TesTask tesTask, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBatchJobAndTaskStateAsync(tesTask, ct), cancellationToken); /// - public Task> GetPoolIdsReferencedByJobsAsync(CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(azureProxy.GetPoolIdsReferencedByJobsAsync, cancellationToken); + Task> IAzureProxy.GetPoolIdsReferencedByJobsAsync(CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(azureProxy.GetPoolIdsReferencedByJobsAsync, cancellationToken); /// - public Task GetStorageAccountKeyAsync(StorageAccountInfo storageAccountInfo, CancellationToken cancellationToken) + Task IAzureProxy.GetStorageAccountKeyAsync(StorageAccountInfo storageAccountInfo, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAndCachingAsync($"{nameof(CachingWithRetriesAzureProxy)}:{storageAccountInfo.Id}", ct => azureProxy.GetStorageAccountKeyAsync(storageAccountInfo, ct), DateTimeOffset.Now.AddHours(1), cancellationToken); /// - public async Task GetStorageAccountInfoAsync(string storageAccountName, CancellationToken cancellationToken) + async Task IAzureProxy.GetStorageAccountInfoAsync(string storageAccountName, CancellationToken cancellationToken) { var cacheKey = $"{nameof(CachingWithRetriesAzureProxy)}:{storageAccountName}"; var storageAccountInfo = cachingAsyncRetry.AppCache.Get(cacheKey); @@ -197,50 +148,37 @@ public async Task GetStorageAccountInfoAsync(string storageA } /// - public Task> ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.ListBlobsAsync(directoryUri, ct), cancellationToken); + Task> IAzureProxy.ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.ListBlobsAsync(directoryUri, ct), cancellationToken); /// - public Task UploadBlobAsync(Uri blobAbsoluteUri, string content, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.UploadBlobAsync(blobAbsoluteUri, content, ct), cancellationToken); + Task IAzureProxy.UploadBlobAsync(Uri blobAbsoluteUri, string content, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.UploadBlobAsync(blobAbsoluteUri, content, ct), cancellationToken); /// - public Task UploadBlobFromFileAsync(Uri blobAbsoluteUri, string filePath, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.UploadBlobFromFileAsync(blobAbsoluteUri, filePath, ct), cancellationToken); + Task IAzureProxy.UploadBlobFromFileAsync(Uri blobAbsoluteUri, string filePath, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.UploadBlobFromFileAsync(blobAbsoluteUri, filePath, ct), cancellationToken); /// - public Task GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBlobPropertiesAsync(blobAbsoluteUri, ct), cancellationToken); + Task IAzureProxy.GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAsync(ct => azureProxy.GetBlobPropertiesAsync(blobAbsoluteUri, ct), cancellationToken); /// - public string GetArmRegion() => azureProxy.GetArmRegion(); - - /// - public async Task CreateBatchPoolAsync(BatchModels.Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken) - { - try - { - return await cachingAsyncRetryExceptWhenExists.ExecuteWithRetryAsync(ct => azureProxy.CreateBatchPoolAsync(poolSpec, isPreemptable, ct), cancellationToken); - } - catch (BatchException exc) when (BatchErrorCodeStrings.PoolExists.Equals(exc.RequestInformation?.BatchError?.Code, StringComparison.OrdinalIgnoreCase)) - { - return poolSpec.Name; - } - } + string IAzureProxy.GetArmRegion() => azureProxy.GetArmRegion(); /// - public Task GetFullAllocationStateAsync(string poolId, CancellationToken cancellationToken) + Task IAzureProxy.GetFullAllocationStateAsync(string poolId, CancellationToken cancellationToken) => cachingAsyncRetry.ExecuteWithRetryAndCachingAsync( $"{nameof(CachingWithRetriesAzureProxy)}:{poolId}", ct => azureProxy.GetFullAllocationStateAsync(poolId, ct), DateTimeOffset.Now.Add(BatchPoolService.RunInterval).Subtract(TimeSpan.FromSeconds(1)), cancellationToken); /// - public IAsyncEnumerable ListComputeNodesAsync(string poolId, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(() => azureProxy.ListComputeNodesAsync(poolId, detailLevel), cachingRetry); + IAsyncEnumerable IAzureProxy.ListComputeNodesAsync(string poolId, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(() => azureProxy.ListComputeNodesAsync(poolId, detailLevel), cachingRetry); /// - public IAsyncEnumerable ListTasksAsync(string jobId, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(() => azureProxy.ListTasksAsync(jobId, detailLevel), cachingRetry); + IAsyncEnumerable IAzureProxy.ListTasksAsync(string jobId, DetailLevel detailLevel) => cachingAsyncRetry.ExecuteWithRetryAsync(() => azureProxy.ListTasksAsync(jobId, detailLevel), cachingRetry); /// - public Task DisableBatchPoolAutoScaleAsync(string poolId, CancellationToken cancellationToken) => azureProxy.DisableBatchPoolAutoScaleAsync(poolId, cancellationToken); + Task IAzureProxy.DisableBatchPoolAutoScaleAsync(string poolId, CancellationToken cancellationToken) => azureProxy.DisableBatchPoolAutoScaleAsync(poolId, cancellationToken); /// - public Task EnableBatchPoolAutoScaleAsync(string poolId, bool preemptable, TimeSpan interval, IAzureProxy.BatchPoolAutoScaleFormulaFactory formulaFactory, CancellationToken cancellationToken) => azureProxy.EnableBatchPoolAutoScaleAsync(poolId, preemptable, interval, formulaFactory, cancellationToken); + Task IAzureProxy.EnableBatchPoolAutoScaleAsync(string poolId, bool preemptable, TimeSpan interval, IAzureProxy.BatchPoolAutoScaleFormulaFactory formulaFactory, CancellationToken cancellationToken) => azureProxy.EnableBatchPoolAutoScaleAsync(poolId, preemptable, interval, formulaFactory, cancellationToken); } } diff --git a/src/TesApi.Web/CachingWithRetriesBase.cs b/src/TesApi.Web/CachingWithRetriesBase.cs new file mode 100644 index 000000000..02a6550b7 --- /dev/null +++ b/src/TesApi.Web/CachingWithRetriesBase.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Linq; +using Microsoft.Azure.Batch.Common; +using Microsoft.Extensions.Logging; +using Tes.ApiClients; +using static Tes.ApiClients.CachingRetryHandler; + +namespace TesApi.Web +{ + /// + /// Common base for caching retry handlers + /// + public abstract class CachingWithRetriesBase + { + /// + /// Sync retry policy. + /// + protected readonly CachingRetryHandlerPolicy cachingRetry; + /// + /// Async retry policy. + /// + protected readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetry; + /// + /// Async retry policy for methods where Exists should not be retried. + /// + protected readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetryExceptWhenExists; + /// + /// Async retry policy for methods where NotFound should not be retried. + /// + protected readonly CachingAsyncRetryHandlerPolicy cachingAsyncRetryExceptWhenNotFound; + /// + /// Logger. + /// + protected readonly ILogger logger; + + /// + /// Constructor to create common retry handlers + /// + /// + /// + protected CachingWithRetriesBase(CachingRetryPolicyBuilder cachingRetryHandler, ILogger logger) + { + ArgumentNullException.ThrowIfNull(cachingRetryHandler); + + this.logger = logger; + + var sleepDuration = new Func((attempt, exception) => (exception as BatchException)?.RequestInformation?.RetryAfter); + + this.cachingRetry = cachingRetryHandler.PolicyBuilder + .OpinionatedRetryPolicy() + .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().SyncBuild(); + + this.cachingAsyncRetry = cachingRetryHandler.PolicyBuilder + .OpinionatedRetryPolicy() + .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); + + this.cachingAsyncRetryExceptWhenExists = cachingRetryHandler.PolicyBuilder + .OpinionatedRetryPolicy(Polly.Policy.Handle(ex => !CreationErrorFoundCodes.Contains(ex.RequestInformation?.BatchError?.Code, StringComparer.OrdinalIgnoreCase))) + .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); + + this.cachingAsyncRetryExceptWhenNotFound = cachingRetryHandler.PolicyBuilder + .OpinionatedRetryPolicy(Polly.Policy.Handle(ex => !DeletionErrorFoundCodes.Contains(ex.RequestInformation?.BatchError?.Code, StringComparer.OrdinalIgnoreCase))) + .WithExceptionBasedWaitWithRetryPolicyOptionsBackup(sleepDuration, backupSkipProvidedIncrements: true).SetOnRetryBehavior(this.logger).AddCaching().AsyncBuild(); + } + + private static readonly string[] CreationErrorFoundCodes = + [ + BatchErrorCodeStrings.TaskExists, + BatchErrorCodeStrings.PoolExists, + BatchErrorCodeStrings.JobExists + ]; + + private static readonly string[] DeletionErrorFoundCodes = + [ + BatchErrorCodeStrings.TaskNotFound, + BatchErrorCodeStrings.PoolNotFound, + BatchErrorCodeStrings.JobNotFound + ]; + } +} diff --git a/src/TesApi.Web/IAzureProxy.cs b/src/TesApi.Web/IAzureProxy.cs index 8d70c330f..34246e5d0 100644 --- a/src/TesApi.Web/IAzureProxy.cs +++ b/src/TesApi.Web/IAzureProxy.cs @@ -6,10 +6,9 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Batch; -using Microsoft.Azure.Batch.Common; using Tes.Models; using TesApi.Web.Storage; -using BatchModels = Microsoft.Azure.Management.Batch.Models; +using BlobModels = Azure.Storage.Blobs.Models; namespace TesApi.Web { @@ -23,7 +22,7 @@ public interface IAzureProxy /// /// /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. Task CreateBatchJobAsync(string jobId, string poolId, CancellationToken cancellationToken); /// @@ -32,14 +31,14 @@ public interface IAzureProxy /// /// /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. Task AddBatchTaskAsync(string tesTaskId, CloudTask cloudTask, string jobId, CancellationToken cancellationToken); /// /// Terminates and deletes an Azure Batch job for /// /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. Task DeleteBatchJobAsync(string jobId, CancellationToken cancellationToken); /// @@ -50,15 +49,6 @@ public interface IAzureProxy /// Task GetStorageAccountInfoAsync(string storageAccountName, CancellationToken cancellationToken); - /// - /// Creates an Azure Batch pool who's lifecycle must be manually managed - /// - /// Contains the specification for the pool. - /// True if nodes in this pool will all be preemptable. False if nodes will all be dedicated. - /// A for controlling the lifetime of the asynchronous operation. - /// (from ) becomes the (aka ). - Task CreateBatchPoolAsync(BatchModels.Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken); - /// /// Gets the combined state of Azure Batch job, task and pool that corresponds to the given TES task /// @@ -72,7 +62,7 @@ public interface IAzureProxy /// /// The unique TES task ID /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. Task DeleteBatchTaskAsync(string taskId, string jobId, CancellationToken cancellationToken); /// @@ -141,7 +131,8 @@ public interface IAzureProxy /// Directory Uri /// A for controlling the lifetime of the asynchronous operation. /// List of blob paths - Task> ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken); + // TODO: change return to IAsyncEnumerable + Task> ListBlobsAsync(Uri directoryUri, CancellationToken cancellationToken); /// /// Fetches the blobs properties @@ -149,14 +140,14 @@ public interface IAzureProxy /// Absolute Blob URI /// A for controlling the lifetime of the asynchronous operation. /// - Task GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken); + Task GetBlobPropertiesAsync(Uri blobAbsoluteUri, CancellationToken cancellationToken); /// /// Gets the list of active pool ids matching the prefix and with creation time older than the minAge /// /// /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Active pool ids Task> GetActivePoolIdsAsync(string prefix, TimeSpan minAge, CancellationToken cancellationToken); @@ -170,22 +161,15 @@ public interface IAzureProxy /// /// Gets the list of pool ids referenced by the jobs /// - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Pool ids Task> GetPoolIdsReferencedByJobsAsync(CancellationToken cancellationToken); - /// - /// Deletes the specified pool - /// - /// The id of the pool. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. - Task DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken); - /// /// Retrieves the specified pool /// /// The of the pool to retrieve. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// A Microsoft.Azure.Batch.DetailLevel used for controlling which properties are retrieved from the service. /// Task GetBatchPoolAsync(string poolId, CancellationToken cancellationToken, DetailLevel detailLevel = default); @@ -194,7 +178,7 @@ public interface IAzureProxy /// Retrieves the specified batch job. /// /// The of the job to retrieve. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// A Microsoft.Azure.Batch.DetailLevel used for controlling which properties are retrieved from the service. /// Task GetBatchJobAsync(string jobId, CancellationToken cancellationToken, DetailLevel detailLevel = default); @@ -220,7 +204,7 @@ public interface IAzureProxy /// /// The id of the pool. /// Enumerable list of s to delete. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Task DeleteBatchComputeNodesAsync(string poolId, IEnumerable computeNodes, CancellationToken cancellationToken); @@ -228,7 +212,7 @@ public interface IAzureProxy /// Gets the allocation state and numbers of targeted and current compute nodes /// /// The id of the pool. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Task GetFullAllocationStateAsync(string poolId, CancellationToken cancellationToken); @@ -242,7 +226,7 @@ public interface IAzureProxy /// Disables AutoScale in a Batch Pool /// /// The id of the pool. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Task DisableBatchPoolAutoScaleAsync(string poolId, CancellationToken cancellationToken); @@ -253,7 +237,7 @@ public interface IAzureProxy /// Type of compute nodes: false if dedicated, otherwise true. /// The interval for periodic reevaluation of the formula. /// A factory function that generates an auto-scale formula. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// A for controlling the lifetime of the asynchronous operation. /// Task EnableBatchPoolAutoScaleAsync(string poolId, bool preemptable, TimeSpan interval, BatchPoolAutoScaleFormulaFactory formulaFactory, CancellationToken cancellationToken); diff --git a/src/TesApi.Web/IBatchPool.cs b/src/TesApi.Web/IBatchPool.cs index 9df54e727..312e00b18 100644 --- a/src/TesApi.Web/IBatchPool.cs +++ b/src/TesApi.Web/IBatchPool.cs @@ -29,7 +29,7 @@ public interface IBatchPool /// /// /// - ValueTask CreatePoolAndJobAsync(Microsoft.Azure.Management.Batch.Models.Pool pool, bool isPreemptible, CancellationToken cancellationToken); + ValueTask CreatePoolAndJobAsync(Azure.ResourceManager.Batch.BatchAccountPoolData pool, bool isPreemptible, CancellationToken cancellationToken); /// /// Connects to the provided pool and associated job in the Batch Account. diff --git a/src/TesApi.Web/Management/ArmBatchQuotaProvider.cs b/src/TesApi.Web/Management/ArmBatchQuotaProvider.cs index 40066eecb..f47dc02ef 100644 --- a/src/TesApi.Web/Management/ArmBatchQuotaProvider.cs +++ b/src/TesApi.Web/Management/ArmBatchQuotaProvider.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.Azure.Management.Batch; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Logging; using TesApi.Web.Management.Models.Quotas; @@ -63,7 +62,7 @@ public async Task GetVmCoreQuotaAsync(bool lowPriority, Cancel if (isDedicatedAndPerVmFamilyCoreQuotaEnforced) { dedicatedCoresPerFamilies = batchQuota.DedicatedCoreQuotaPerVMFamily - .Select(r => new BatchVmCoresPerFamily(r.Name, r.CoreQuota)) + .Select(r => new BatchVmCoresPerFamily(r.Name, r.CoreQuota ?? 0)) .ToList(); } @@ -79,7 +78,11 @@ public async Task GetVmCoreQuotaAsync(bool lowPriority, Cancel /// /// public virtual async Task GetBatchAccountQuotasAsync(CancellationToken cancellationToken) - => await appCache.GetOrCreateAsync(clientsFactory.BatchAccountInformation.ToString(), _1 => GetBatchAccountQuotasImplAsync(cancellationToken)); // TODO: Consider expiring the quota daily, because quota can be changed. + => await appCache.GetOrCreateAsync(clientsFactory.BatchAccountInformation.ToString(), entry => + { + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromDays(1); + return GetBatchAccountQuotasImplAsync(cancellationToken); + }); private async Task GetBatchAccountQuotasImplAsync(CancellationToken cancellationToken) { @@ -87,25 +90,21 @@ private async Task GetBatchAccountQuotasImplAsync(Cance { logger.LogInformation($"Getting quota information for Batch Account: {clientsFactory.BatchAccountInformation.Name} calling ARM API"); - using var managementClient = await clientsFactory.CreateBatchAccountManagementClient(cancellationToken); - var batchAccount = await managementClient.BatchAccount.GetAsync(clientsFactory.BatchAccountInformation.ResourceGroupName, clientsFactory.BatchAccountInformation.Name, cancellationToken: cancellationToken); - - if (batchAccount == null) - { - throw new InvalidOperationException( - $"Batch Account was not found. Account name:{clientsFactory.BatchAccountInformation.Name}. Resource group:{clientsFactory.BatchAccountInformation.ResourceGroupName}"); - } - - return new AzureBatchAccountQuotas - { - ActiveJobAndJobScheduleQuota = batchAccount.ActiveJobAndJobScheduleQuota, - DedicatedCoreQuota = batchAccount.DedicatedCoreQuota ?? 0, - DedicatedCoreQuotaPerVMFamily = batchAccount.DedicatedCoreQuotaPerVMFamily, - DedicatedCoreQuotaPerVMFamilyEnforced = batchAccount.DedicatedCoreQuotaPerVMFamilyEnforced, - LowPriorityCoreQuota = batchAccount.LowPriorityCoreQuota ?? 0, - PoolQuota = batchAccount.PoolQuota, - - }; + var managementClient = clientsFactory.CreateBatchAccountManagementClient(); + var batchAccount = (await managementClient.GetAsync(cancellationToken: cancellationToken)).Value.Data; + + return batchAccount is null + ? throw new InvalidOperationException( + $"Batch Account was not found. Account name:{clientsFactory.BatchAccountInformation.Name}. Resource group:{clientsFactory.BatchAccountInformation.ResourceGroupName}") + : new AzureBatchAccountQuotas + { + ActiveJobAndJobScheduleQuota = batchAccount.ActiveJobAndJobScheduleQuota ?? 0, + DedicatedCoreQuota = batchAccount.DedicatedCoreQuota ?? 0, + DedicatedCoreQuotaPerVMFamily = batchAccount.DedicatedCoreQuotaPerVmFamily, + DedicatedCoreQuotaPerVMFamilyEnforced = batchAccount.IsDedicatedCoreQuotaPerVmFamilyEnforced ?? false, + LowPriorityCoreQuota = batchAccount.LowPriorityCoreQuota ?? 0, + PoolQuota = batchAccount.PoolQuota ?? 0, + }; } catch (Exception ex) { diff --git a/src/TesApi.Web/Management/ArmResourceInformationFinder.cs b/src/TesApi.Web/Management/ArmResourceInformationFinder.cs index 1bbc76b5e..2fd18740f 100644 --- a/src/TesApi.Web/Management/ArmResourceInformationFinder.cs +++ b/src/TesApi.Web/Management/ArmResourceInformationFinder.cs @@ -2,16 +2,12 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; -using Azure.Identity; -using CommonUtilities; -using CommonUtilities.AzureCloud; -using Microsoft.Azure.Management.ApplicationInsights.Management; -using Microsoft.Azure.Management.Batch; -using Microsoft.Rest; -using TesApi.Web.Extensions; +using Azure; +using Azure.ResourceManager; namespace TesApi.Web.Management { @@ -23,18 +19,21 @@ public static class ArmResourceInformationFinder /// /// Looks up the AppInsights instrumentation key in subscriptions the TES services has access to /// - /// - /// Azure cloud identity configuration + /// AppInsights account name + /// A credential capable of providing an OAuth token. + /// The information of an Azure Cloud environment. + /// A for controlling the lifetime of the asynchronous operation. /// - /// - public static Task GetAppInsightsConnectionStringAsync(string accountName, AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) + public static Task GetAppInsightsConnectionStringFromAccountNameAsync(string accountName, Azure.Core.TokenCredential tokenCredential, ArmEnvironment armEnvironment, CancellationToken cancellationToken) { + ArgumentException.ThrowIfNullOrEmpty(accountName); + return GetAzureResourceAsync( - clientFactory: (tokenCredentials, subscription) => new ApplicationInsightsManagementClient(tokenCredentials) { SubscriptionId = subscription, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl) }, - listAsync: (client, ct) => client.Components.ListAsync(ct), - listNextAsync: (client, link, ct) => client.Components.ListNextAsync(link, ct), + tokenCredential, armEnvironment, + listAsync: Azure.ResourceManager.ApplicationInsights.ApplicationInsightsExtensions.GetApplicationInsightsComponentsAsync, + getDataAsync: async (subscriptionResource, token) => await subscriptionResource.GetAsync(token), + getData: subscriptionResource => subscriptionResource.Data, predicate: a => a.ApplicationId.Equals(accountName, StringComparison.OrdinalIgnoreCase), - azureCloudConfig: azureCloudConfig, cancellationToken: cancellationToken, finalize: a => a.ConnectionString); } @@ -43,76 +42,71 @@ public static Task GetAppInsightsConnectionStringAsync(string accountNam /// Attempts to get the batch resource information using the ARM api. /// Returns null if the resource was not found or the account does not have access. /// - /// batch account name - /// Azure cloud identity configuration + /// Batch account name + /// A credential capable of providing an OAuth token. + /// The information of an Azure Cloud environment. + /// A for controlling the lifetime of the asynchronous operation. /// - /// - public static Task TryGetResourceInformationFromAccountNameAsync(string batchAccountName, AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) + public static Task TryGetBatchAccountInformationFromAccountNameAsync(string batchAccountName, Azure.Core.TokenCredential tokenCredential, ArmEnvironment armEnvironment, CancellationToken cancellationToken) { - //TODO: look if a newer version of the management SDK provides a simpler way to look for this information . + ArgumentException.ThrowIfNullOrEmpty(batchAccountName); + return GetAzureResourceAsync( - clientFactory: (tokenCredentials, subscription) => new BatchManagementClient(tokenCredentials) { SubscriptionId = subscription, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl) }, - listAsync: (client, ct) => client.BatchAccount.ListAsync(ct), - listNextAsync: (client, link, ct) => client.BatchAccount.ListNextAsync(link, ct), + tokenCredential, armEnvironment, + listAsync: Azure.ResourceManager.Batch.BatchExtensions.GetBatchAccountsAsync, + getDataAsync: async (subscriptionResource, token) => await subscriptionResource.GetAsync(token), + getData: subscriptionResource => subscriptionResource.Data, predicate: a => a.Name.Equals(batchAccountName, StringComparison.OrdinalIgnoreCase), - azureCloudConfig: azureCloudConfig, cancellationToken: cancellationToken, - finalize: batchAccount => BatchAccountResourceInformation.FromBatchResourceId(batchAccount.Id, batchAccount.Location, $"https://{batchAccount.AccountEndpoint}")); - } - - private static async Task GetAzureAccessTokenAsync(AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken = default) - { - var defaultCredential = new DefaultAzureCredential(new DefaultAzureCredentialOptions { AuthorityHost = new Uri(azureCloudConfig.Authentication.LoginEndpointUrl) }); - var accessToken = await defaultCredential.GetTokenAsync(new Azure.Core.TokenRequestContext([azureCloudConfig.DefaultTokenScope]), cancellationToken); - return accessToken.Token; + finalize: batchAccount => BatchAccountResourceInformation.FromBatchResourceId(batchAccount.Id.ToString(), batchAccount.Location?.Name, $"{Uri.UriSchemeHttps}://{batchAccount.AccountEndpoint}")); } /// /// Looks up an Azure resource with management clients that use enumerators /// /// Value to return - /// Type of Azure management client to use to locate resources of type /// Type of Azure resource to enumerate/locate - /// Returns management client appropriate for enumerating resources of . A and the SubscriptionId are passed to this method as parameters. - /// ListAsync method from operational parameter on . Parameters are the returned by and . - /// ListNextAsync method from operational parameter on . Parameters are the returned by , the from the previous server call, and . - /// Returns true when the desired is found. - /// - /// - /// Converts to . Required if is not . + /// Type of Azure resource data + /// A credential capable of providing an OAuth token. + /// The information of an Azure Cloud environment. + /// Get{TResource}sAsync-style extension method with this parameter of type and one other parameter . + /// GetAsync-style extension method with this parameter of and one parameter . + /// Accessor to the Data property of , expected to return a . Exists to avoid using reflection. + /// Returns true when the desired is found. + /// A for controlling the lifetime of the asynchronous operation. + /// Converts to . Required if is not . /// The derived from the first that satisfies the condition in , else default. - private static async Task GetAzureResourceAsync( - Func clientFactory, - Func>> listAsync, - Func>> listNextAsync, - Predicate predicate, - AzureCloudConfig azureCloudConfig, + private static async Task GetAzureResourceAsync( + Azure.Core.TokenCredential tokenCredential, + ArmEnvironment armEnvironment, + Func> listAsync, + Func>> getDataAsync, + Func getData, + Predicate predicate, CancellationToken cancellationToken, - Func finalize = default) - where TAzManagementClient : Microsoft.Rest.Azure.IAzureClient, IDisposable + Func finalize = default) + where TResource : ArmResource { if (typeof(TResult) == typeof(TResource)) { finalize ??= new(a => (TResult)Convert.ChangeType(a, typeof(TResult))); } - ArgumentNullException.ThrowIfNull(clientFactory); + ArgumentNullException.ThrowIfNull(tokenCredential); + ArgumentNullException.ThrowIfNull(armEnvironment); ArgumentNullException.ThrowIfNull(listAsync); - ArgumentNullException.ThrowIfNull(listNextAsync); + ArgumentNullException.ThrowIfNull(getDataAsync); + ArgumentNullException.ThrowIfNull(getData); ArgumentNullException.ThrowIfNull(predicate); ArgumentNullException.ThrowIfNull(finalize); - var tokenCredentials = new TokenCredentials(await GetAzureAccessTokenAsync(azureCloudConfig, cancellationToken)); - var azureManagementClient = await AzureManagementClientsFactory.GetAzureManagementClientAsync(azureCloudConfig, cancellationToken); - - var subscriptions = (await azureManagementClient.Subscriptions.ListAsync(cancellationToken: cancellationToken)).ToAsyncEnumerable().Select(s => s.SubscriptionId); + var armClient = new ArmClient(tokenCredential, null, new ArmClientOptions { Environment = armEnvironment }); - await foreach (var subId in subscriptions.WithCancellation(cancellationToken)) + await foreach (var subResource in armClient.GetSubscriptions().SelectAwaitWithCancellation(async (sub, token) => (await sub.GetAsync(token)).Value).WithCancellation(CancellationToken.None)) { - using var client = clientFactory(tokenCredentials, subId); - - var item = await (await listAsync(client, cancellationToken)) - .ToAsyncEnumerable((page, ct) => listNextAsync(client, page, ct)) + var item = await listAsync(subResource, cancellationToken) + .SelectAwaitWithCancellation(async (subscriptionResource, token) => await getDataAsync(subscriptionResource, token)) + .Select(response => getData(response.Value)) .FirstOrDefaultAsync(a => predicate(a), cancellationToken); if (item is not null) diff --git a/src/TesApi.Web/Management/AzureManagementClientsFactory.cs b/src/TesApi.Web/Management/AzureManagementClientsFactory.cs index dcec9780a..ae1a1f37e 100644 --- a/src/TesApi.Web/Management/AzureManagementClientsFactory.cs +++ b/src/TesApi.Web/Management/AzureManagementClientsFactory.cs @@ -2,25 +2,22 @@ // Licensed under the MIT License. using System; -using System.Threading; -using System.Threading.Tasks; -using Azure.Identity; +using Azure.ResourceManager; +using Azure.ResourceManager.Batch; +using CommonUtilities; using CommonUtilities.AzureCloud; -using Microsoft.Azure.Management.Batch; -using Microsoft.Azure.Management.ResourceManager.Fluent.Authentication; -using Microsoft.Rest; -using FluentAzure = Microsoft.Azure.Management.Fluent.Azure; namespace TesApi.Web.Management { /// - /// Factory if ARM management clients. + /// Factory of ARM management clients. /// public class AzureManagementClientsFactory { private readonly BatchAccountResourceInformation batchAccountInformation; private readonly AzureCloudConfig azureCloudConfig; + private readonly AzureServicesConnectionStringCredentialOptions credentialOptions; /// /// Batch account resource information. @@ -32,11 +29,13 @@ public class AzureManagementClientsFactory /// /// > /// + /// /// - public AzureManagementClientsFactory(BatchAccountResourceInformation batchAccountInformation, AzureCloudConfig azureCloudConfig) + public AzureManagementClientsFactory(BatchAccountResourceInformation batchAccountInformation, AzureCloudConfig azureCloudConfig, AzureServicesConnectionStringCredentialOptions credentialOptions) { ArgumentNullException.ThrowIfNull(batchAccountInformation); ArgumentNullException.ThrowIfNull(azureCloudConfig); + ArgumentNullException.ThrowIfNull(credentialOptions); if (string.IsNullOrEmpty(batchAccountInformation.SubscriptionId)) { @@ -48,8 +47,11 @@ public AzureManagementClientsFactory(BatchAccountResourceInformation batchAccoun throw new ArgumentException("Batch account information does not contain the resource group name.", nameof(batchAccountInformation)); } + credentialOptions.AuthorityHost = azureCloudConfig.AuthorityHost; + this.batchAccountInformation = batchAccountInformation; this.azureCloudConfig = azureCloudConfig; + this.credentialOptions = credentialOptions; } /// @@ -57,39 +59,16 @@ public AzureManagementClientsFactory(BatchAccountResourceInformation batchAccoun /// protected AzureManagementClientsFactory() { } - private static async Task GetAzureAccessTokenAsync(CancellationToken cancellationToken, string authorityHost, string scope = "https://management.azure.com//.default") - => (await new DefaultAzureCredential(new DefaultAzureCredentialOptions { AuthorityHost = new Uri(authorityHost) }).GetTokenAsync(new Azure.Core.TokenRequestContext([scope]), cancellationToken)).Token; - /// /// Creates Batch Account management client using AAD authentication. /// Configure to the subscription id that contains the batch account. /// - /// A for controlling the lifetime of the asynchronous operation. - /// - public async Task CreateBatchAccountManagementClient(CancellationToken cancellationToken) - => new BatchManagementClient(new Uri(azureCloudConfig.ResourceManagerUrl), new TokenCredentials(await GetAzureAccessTokenAsync(cancellationToken, authorityHost: azureCloudConfig.Authentication.LoginEndpointUrl, scope: azureCloudConfig.DefaultTokenScope))) { SubscriptionId = batchAccountInformation.SubscriptionId }; - - /// - /// Creates a new instance of Azure Management Client with the default credentials and subscription. - /// - /// A for controlling the lifetime of the asynchronous operation. - /// - public async Task CreateAzureManagementClientAsync(CancellationToken cancellationToken) - => await GetAzureManagementClientAsync(azureCloudConfig, cancellationToken); - - /// - /// Creates a new instance of Azure Management client - /// - /// - /// A for controlling the lifetime of the asynchronous operation. /// - public static async Task GetAzureManagementClientAsync(AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) - { - var accessToken = await GetAzureAccessTokenAsync(cancellationToken, authorityHost: azureCloudConfig.Authentication.LoginEndpointUrl, scope: azureCloudConfig.DefaultTokenScope); - var azureCredentials = new AzureCredentials(new TokenCredentials(accessToken), null, null, azureCloudConfig.AzureEnvironment); - var azureClient = FluentAzure.Authenticate(azureCredentials); - - return azureClient; - } + public BatchAccountResource CreateBatchAccountManagementClient() + => new ArmClient( + new AzureServicesConnectionStringCredential(credentialOptions), + batchAccountInformation.SubscriptionId, + new() { Environment = azureCloudConfig.ArmEnvironment }) + .GetBatchAccountResource(BatchAccountResource.CreateResourceIdentifier(batchAccountInformation.SubscriptionId, batchAccountInformation.ResourceGroupName, batchAccountInformation.Name)); } } diff --git a/src/TesApi.Web/Management/Batch/ArmBatchPoolManager.cs b/src/TesApi.Web/Management/Batch/ArmBatchPoolManager.cs index 195d26281..661877200 100644 --- a/src/TesApi.Web/Management/Batch/ArmBatchPoolManager.cs +++ b/src/TesApi.Web/Management/Batch/ArmBatchPoolManager.cs @@ -2,11 +2,10 @@ // Licensed under the MIT License. using System; +using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.Azure.Batch; -using Microsoft.Azure.Management.Batch; -using Microsoft.Azure.Management.Batch.Models; +using Azure.ResourceManager.Batch; using Microsoft.Extensions.Logging; namespace TesApi.Web.Management.Batch @@ -17,7 +16,7 @@ namespace TesApi.Web.Management.Batch public class ArmBatchPoolManager : IBatchPoolManager { - private readonly ILogger logger; + private readonly ILogger logger; private readonly AzureManagementClientsFactory azureClientsFactory; /// @@ -36,24 +35,28 @@ public ArmBatchPoolManager(AzureManagementClientsFactory azureClientsFactory, } /// - public async Task CreateBatchPoolAsync(Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken) + public async Task CreateBatchPoolAsync(BatchAccountPoolData poolSpec, bool isPreemptable, CancellationToken cancellationToken) { + var nameItem = poolSpec.Metadata.Single(i => string.IsNullOrEmpty(i.Name)); + try { - var batchManagementClient = await azureClientsFactory.CreateBatchAccountManagementClient(cancellationToken); + poolSpec.Metadata.Remove(nameItem); + + var batchManagementClient = azureClientsFactory.CreateBatchAccountManagementClient(); - logger.LogInformation("Creating manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}", poolSpec.Name, poolSpec.VmSize, isPreemptable); + logger.LogInformation("Creating manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}", nameItem.Value, poolSpec.VmSize, isPreemptable); - var pool = await batchManagementClient.Pool.CreateAsync(azureClientsFactory.BatchAccountInformation.ResourceGroupName, azureClientsFactory.BatchAccountInformation.Name, poolSpec.Name, poolSpec, cancellationToken: cancellationToken); + _ = await batchManagementClient.GetBatchAccountPools().CreateOrUpdateAsync(Azure.WaitUntil.Completed, nameItem.Value, poolSpec, cancellationToken: cancellationToken); - logger.LogInformation("Successfully created manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}", poolSpec.Name, poolSpec.VmSize, isPreemptable); + logger.LogInformation("Successfully created manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}", nameItem.Value, poolSpec.VmSize, isPreemptable); - return pool.Name; + return nameItem.Value; } catch (Exception exc) { var batchError = Newtonsoft.Json.JsonConvert.SerializeObject((exc as Microsoft.Azure.Batch.Common.BatchException)?.RequestInformation?.BatchError); - logger.LogError(exc, "Error trying to create manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}. Batch error: {BatchError}", poolSpec.Name, poolSpec.VmSize, isPreemptable, batchError); + logger.LogError(exc, "Error trying to create manual batch pool named {PoolName} with vmSize {PoolVmSize} and low priority {IsPreemptable}. Batch error: {BatchError}", nameItem.Value, poolSpec.VmSize, isPreemptable, batchError); throw; } } @@ -63,18 +66,18 @@ public async Task DeleteBatchPoolAsync(string poolId, CancellationToken cancella { try { - var batchManagementClient = await azureClientsFactory.CreateBatchAccountManagementClient(cancellationToken); + var batchManagementClient = azureClientsFactory.CreateBatchAccountManagementClient(); logger.LogInformation( @"Deleting pool with the id/name:{PoolName} in Batch account:{BatchAccountName}", poolId, azureClientsFactory.BatchAccountInformation.Name); - _ = await batchManagementClient.Pool.DeleteAsync( - azureClientsFactory.BatchAccountInformation.ResourceGroupName, - azureClientsFactory.BatchAccountInformation.Name, poolId, cancellationToken: cancellationToken); + _ = await batchManagementClient.GetBatchAccountPools().Get(BatchAccountPoolResource.CreateResourceIdentifier( + azureClientsFactory.BatchAccountInformation.SubscriptionId, azureClientsFactory.BatchAccountInformation.ResourceGroupName, + azureClientsFactory.BatchAccountInformation.Name, poolId), cancellationToken: cancellationToken).Value + .DeleteAsync(Azure.WaitUntil.Completed, cancellationToken); logger.LogInformation( @"Successfully deleted pool with the id/name:{PoolName} in Batch account:{BatchAccountName}", poolId, azureClientsFactory.BatchAccountInformation.Name); - } catch (Exception e) { diff --git a/src/TesApi.Web/Management/Batch/CachingWithRetriesBatchPoolManager.cs b/src/TesApi.Web/Management/Batch/CachingWithRetriesBatchPoolManager.cs new file mode 100644 index 000000000..c43deffbc --- /dev/null +++ b/src/TesApi.Web/Management/Batch/CachingWithRetriesBatchPoolManager.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Azure.ResourceManager.Batch; +using Microsoft.Azure.Batch.Common; +using Microsoft.Extensions.Logging; +using Tes.ApiClients; + +namespace TesApi.Web.Management.Batch +{ + /// + /// Implements caching and retries for . + /// + public class CachingWithRetriesBatchPoolManager : CachingWithRetriesBase, IBatchPoolManager + { + private readonly IBatchPoolManager batchPoolManager; + + /// + /// Constructor to create a cache of + /// + /// + /// + /// + public CachingWithRetriesBatchPoolManager(IBatchPoolManager batchPoolManager, CachingRetryPolicyBuilder cachingRetryHandler, ILogger logger) + : base(cachingRetryHandler, logger) + { + ArgumentNullException.ThrowIfNull(batchPoolManager); + ArgumentNullException.ThrowIfNull(cachingRetryHandler); + + this.batchPoolManager = batchPoolManager; + } + + + /// + async Task IBatchPoolManager.DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken) + { + try + { + await cachingAsyncRetryExceptWhenNotFound.ExecuteWithRetryAsync(ct => batchPoolManager.DeleteBatchPoolAsync(poolId, ct), cancellationToken); + } + catch (BatchException exc) when (BatchErrorCodeStrings.TaskNotFound.Equals(exc.RequestInformation?.BatchError?.Code, StringComparison.OrdinalIgnoreCase)) + { } + } + + /// + async Task IBatchPoolManager.CreateBatchPoolAsync(BatchAccountPoolData poolSpec, bool isPreemptable, CancellationToken cancellationToken) + { + try + { + return await cachingAsyncRetryExceptWhenExists.ExecuteWithRetryAsync(ct => batchPoolManager.CreateBatchPoolAsync(poolSpec, isPreemptable, ct), cancellationToken); + } + catch (BatchException exc) when (BatchErrorCodeStrings.PoolExists.Equals(exc.RequestInformation?.BatchError?.Code, StringComparison.OrdinalIgnoreCase)) + { + return poolSpec.Name; + } + } + } +} diff --git a/src/TesApi.Web/Management/Batch/IBatchPoolManager.cs b/src/TesApi.Web/Management/Batch/IBatchPoolManager.cs index 048e05912..2c17ca352 100644 --- a/src/TesApi.Web/Management/Batch/IBatchPoolManager.cs +++ b/src/TesApi.Web/Management/Batch/IBatchPoolManager.cs @@ -3,8 +3,8 @@ using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager.Batch; using Microsoft.Azure.Batch; -using BatchModels = Microsoft.Azure.Management.Batch.Models; namespace TesApi.Web.Management.Batch { @@ -16,19 +16,17 @@ public interface IBatchPoolManager /// /// Creates an Azure Batch pool who's lifecycle must be manually managed /// - /// Contains information about the pool. becomes the + /// Contains the specification for the pool. /// True if nodes in this pool will all be preemptable. False if nodes will all be dedicated. - /// The name (aka ) that identifies the created pool. - /// - Task CreateBatchPoolAsync(BatchModels.Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken); + /// A for controlling the lifetime of the asynchronous operation. + /// (from ) becomes the (aka ). + Task CreateBatchPoolAsync(BatchAccountPoolData poolSpec, bool isPreemptable, CancellationToken cancellationToken); /// /// Deletes the specified pool /// /// The id of the pool. - /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. - /// - Task DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken = default); - + /// A for controlling the lifetime of the asynchronous operation. + Task DeleteBatchPoolAsync(string poolId, CancellationToken cancellationToken); } } diff --git a/src/TesApi.Web/Management/Batch/MappingProfilePoolToWsmRequest.cs b/src/TesApi.Web/Management/Batch/MappingProfilePoolToWsmRequest.cs index ad51045db..370bc1567 100644 --- a/src/TesApi.Web/Management/Batch/MappingProfilePoolToWsmRequest.cs +++ b/src/TesApi.Web/Management/Batch/MappingProfilePoolToWsmRequest.cs @@ -4,15 +4,16 @@ using System; using System.Linq; using AutoMapper; -using Microsoft.Azure.Management.Batch.Models; -using Microsoft.Azure.Management.ContainerRegistry.Fluent.Models; +using Azure.Core; +using Azure.ResourceManager.Batch; +using Azure.ResourceManager.Batch.Models; using TesApi.Web.Management.Models.Terra; using TesApi.Web.Runner; namespace TesApi.Web.Management.Batch { /// - /// Automapper mapping profile for Batch Pool to Wsm Create Batch API + /// AutoMapper mapping profile for Batch Pool to Wsm Create Batch API /// public class MappingProfilePoolToWsmRequest : Profile { @@ -21,37 +22,40 @@ public class MappingProfilePoolToWsmRequest : Profile /// public MappingProfilePoolToWsmRequest() { - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap() - .ForMember(dest => dest.EvaluationInterval, opt => opt.MapFrom(src => Convert.ToInt64(src.EvaluationInterval.Value.TotalMinutes))); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - CreateMap(); - //TODO: This mapping to be updated once the WSM API changes to support the correct values - CreateMap() - .ForMember(dest => dest.ClientId, opt => opt.MapFrom(src => src.ClientId)) - .ForMember(dest => dest.ResourceGroupName, opt => opt.Ignore()) - .ForMember(dest => dest.Name, opt => opt.Ignore()); - CreateMap() - .ForMember(dest => dest.Id, opt => opt.MapFrom(src => src.Name)) - .ForMember(dest => dest.UserAssignedIdentities, opt => opt.MapFrom(src => src.Identity.UserAssignedIdentities.Select(kvp => new ApiUserAssignedIdentity() { Name = TryGetManagedIdentityNameFromResourceId(kvp.Key), ClientId = kvp.Value.ClientId }))); + CreateMap() + .ForMember(dst => dst.VirtualMachineConfiguration, opt => opt.MapFrom(src => src.VmConfiguration)) + .ForMember(dst => dst.CloudServiceConfiguration, opt => opt.MapFrom(src => src.CloudServiceConfiguration)); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap() + .ForMember(dst => dst.EvaluationInterval, opt => opt.MapFrom(src => Convert.ToInt64(src.EvaluationInterval.Value.TotalMinutes))); + CreateMap(); + CreateMap() + .ForMember(dst => dst.AutoStorageContainerName, opt => opt.MapFrom(src => src.AutoBlobContainerName)) + .ForMember(dst => dst.HttpUrl, opt => opt.MapFrom(src => src.HttpUri)) + .ForMember(dst => dst.IdentityReference, opt => opt.MapFrom(src => src.IdentityResourceId)) + .ForMember(dst => dst.StorageContainerUrl, opt => opt.MapFrom(src => src.BlobContainerUri)); + CreateMap() + .ForMember(dst => dst.ResourceId, opt => opt.MapFrom(src => src.ToString())); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap() + .ForMember(dst => dst.IdentityReference, opt => opt.MapFrom(src => src.IdentityResourceId)); + CreateMap() + .ForPath(dst => dst.EndpointConfiguration.InboundNatPools, opt => opt.MapFrom(src => src.EndpointInboundNatPools)); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap(); + CreateMap() + .ForMember(dst => dst.Id, opt => opt.MapFrom(src => src.Name)) + .ForMember(dst => dst.UserAssignedIdentities, opt => opt.MapFrom(src => src.Identity.UserAssignedIdentities.Select(kvp => new ApiUserAssignedIdentity() { Name = TryGetManagedIdentityNameFromResourceId(kvp.Key) }))); } /// diff --git a/src/TesApi.Web/Management/Batch/PoolMetadataReader.cs b/src/TesApi.Web/Management/Batch/PoolMetadataReader.cs index 2222c1b3e..5109009a7 100644 --- a/src/TesApi.Web/Management/Batch/PoolMetadataReader.cs +++ b/src/TesApi.Web/Management/Batch/PoolMetadataReader.cs @@ -3,8 +3,9 @@ using System; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Azure.Batch; -using Microsoft.Azure.Batch.Auth; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using TesApi.Web.Management.Configuration; @@ -17,10 +18,9 @@ namespace TesApi.Web.Management.Batch /// public class PoolMetadataReader { - private readonly ILogger logger; + private readonly ILogger logger; private readonly TerraOptions terraOptions; - private readonly BatchAccountOptions batchAccountOptions; - private readonly BatchClient batchClient; + private readonly IAzureProxy azureProxy; /// /// Parameter-less constructor of PoolMetadataReader @@ -30,23 +30,21 @@ protected PoolMetadataReader() { } /// /// Constructor of PoolMetadataReader /// - /// /// + /// /// > - public PoolMetadataReader(IOptions batchAccountOptions, IOptions terraOptions, ILogger logger) + public PoolMetadataReader(IOptions terraOptions, IAzureProxy azureProxy, ILogger logger) { - ArgumentNullException.ThrowIfNull(batchAccountOptions); + ArgumentNullException.ThrowIfNull(azureProxy); ArgumentNullException.ThrowIfNull(terraOptions); ArgumentNullException.ThrowIfNull(logger); - this.batchAccountOptions = batchAccountOptions.Value; this.terraOptions = terraOptions.Value; + this.azureProxy = azureProxy; this.logger = logger; ValidateOptions(); - - batchClient = CreateBatchClientFromOptions(); } /// @@ -54,14 +52,14 @@ public PoolMetadataReader(IOptions batchAccountOptions, IOp /// /// Pool id /// Metadata key + /// /// /// When pool is not found - public virtual string GetMetadataValue(string poolId, string key) + public virtual async ValueTask GetMetadataValueAsync(string poolId, string key, CancellationToken cancellationToken) { + logger.LogInformation(@"Getting metadata from pool {PoolId}. Key {MetadataKey}", poolId, key); - logger.LogInformation($"Getting metadata from pool {poolId}. Key {key}"); - - var poolMetadata = batchClient.PoolOperations.GetPool(poolId)?.Metadata; + var poolMetadata = (await azureProxy.GetBatchPoolAsync(poolId, cancellationToken: cancellationToken, new ODATADetailLevel { SelectClause = "metadata" }))?.Metadata; if (poolMetadata is null) { @@ -71,19 +69,9 @@ public virtual string GetMetadataValue(string poolId, string key) return poolMetadata.SingleOrDefault(m => m.Name.Equals(key))?.Value; } - private BatchClient CreateBatchClientFromOptions() - { - return BatchClient.Open(new BatchSharedKeyCredentials(batchAccountOptions.BaseUrl, - batchAccountOptions.AccountName, batchAccountOptions.AppKey)); - } - private void ValidateOptions() { ArgumentException.ThrowIfNullOrEmpty(terraOptions.WorkspaceId, nameof(terraOptions.WorkspaceId)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.AccountName, nameof(batchAccountOptions.AccountName)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.AppKey, nameof(batchAccountOptions.AppKey)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.BaseUrl, nameof(batchAccountOptions.BaseUrl)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.ResourceGroup, nameof(batchAccountOptions.ResourceGroup)); } } } diff --git a/src/TesApi.Web/Management/Batch/TerraBatchPoolManager.cs b/src/TesApi.Web/Management/Batch/TerraBatchPoolManager.cs index d51814338..6a1d9ff6f 100644 --- a/src/TesApi.Web/Management/Batch/TerraBatchPoolManager.cs +++ b/src/TesApi.Web/Management/Batch/TerraBatchPoolManager.cs @@ -6,9 +6,7 @@ using System.Threading; using System.Threading.Tasks; using AutoMapper; -using Microsoft.Azure.Batch; -using Microsoft.Azure.Batch.Auth; -using Microsoft.Azure.Management.Batch.Models; +using Azure.ResourceManager.Batch; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Tes.ApiClients; @@ -18,7 +16,7 @@ namespace TesApi.Web.Management.Batch { /// - /// + /// Provides management plane operations for Azure Batch Pools using Terra /// public class TerraBatchPoolManager : IBatchPoolManager { @@ -35,28 +33,28 @@ public class TerraBatchPoolManager : IBatchPoolManager private readonly IMapper mapper; private readonly ILogger logger; private readonly TerraOptions terraOptions; - private readonly BatchAccountOptions batchAccountOptions; + private readonly PoolMetadataReader poolMetadataReader; /// /// Provides batch pool created and delete operations via the Terra api. /// /// /// - /// + /// /// /// - public TerraBatchPoolManager(TerraWsmApiClient terraWsmApiClient, IMapper mapper, IOptions terraOptions, IOptions batchAccountOptions, ILogger logger) + public TerraBatchPoolManager(TerraWsmApiClient terraWsmApiClient, IMapper mapper, PoolMetadataReader poolMetadata, IOptions terraOptions, ILogger logger) { ArgumentNullException.ThrowIfNull(terraWsmApiClient); ArgumentNullException.ThrowIfNull(mapper); + ArgumentNullException.ThrowIfNull(poolMetadata); ArgumentNullException.ThrowIfNull(logger); ArgumentNullException.ThrowIfNull(terraOptions); - ArgumentNullException.ThrowIfNull(batchAccountOptions); this.terraWsmApiClient = terraWsmApiClient; this.mapper = mapper; + this.poolMetadataReader = poolMetadata; this.logger = logger; - this.batchAccountOptions = batchAccountOptions.Value; this.terraOptions = terraOptions.Value; ValidateOptions(); @@ -70,12 +68,14 @@ public TerraBatchPoolManager(TerraWsmApiClient terraWsmApiClient, IMapper mapper /// /// A for controlling the lifetime of the asynchronous operation. /// - public async Task CreateBatchPoolAsync(Pool poolSpec, bool isPreemptable, CancellationToken cancellationToken) + public async Task CreateBatchPoolAsync(BatchAccountPoolData poolSpec, bool isPreemptable, CancellationToken cancellationToken) { var resourceId = Guid.NewGuid(); var resourceName = $"TES-{resourceId}"; + var nameItem = poolSpec.Metadata.Single(i => string.IsNullOrEmpty(i.Name)); + poolSpec.Metadata.Remove(nameItem); - var apiRequest = new ApiCreateBatchPoolRequest() + ApiCreateBatchPoolRequest apiRequest = new() { Common = new ApiCommon { @@ -89,7 +89,7 @@ public async Task CreateBatchPoolAsync(Pool poolSpec, bool isPreemptable AzureBatchPool = mapper.Map(poolSpec), }; - apiRequest.AzureBatchPool.Id = poolSpec.Name; + apiRequest.AzureBatchPool.Id = nameItem.Value; AddResourceIdToPoolMetadata(apiRequest, resourceId); @@ -100,11 +100,12 @@ public async Task CreateBatchPoolAsync(Pool poolSpec, bool isPreemptable private static void AddResourceIdToPoolMetadata(ApiCreateBatchPoolRequest apiRequest, Guid resourceId) { - var resourceIdMetadataItem = - new ApiBatchPoolMetadataItem() { Name = TerraResourceIdMetadataKey, Value = resourceId.ToString() }; + ApiBatchPoolMetadataItem resourceIdMetadataItem = new() + { Name = TerraResourceIdMetadataKey, Value = resourceId.ToString() }; + if (apiRequest.AzureBatchPool.Metadata is null) { - apiRequest.AzureBatchPool.Metadata = new ApiBatchPoolMetadataItem[] { resourceIdMetadataItem }; + apiRequest.AzureBatchPool.Metadata = [resourceIdMetadataItem]; return; } @@ -112,7 +113,7 @@ private static void AddResourceIdToPoolMetadata(ApiCreateBatchPoolRequest apiReq metadataList.Add(resourceIdMetadataItem); - apiRequest.AzureBatchPool.Metadata = metadataList.ToArray(); + apiRequest.AzureBatchPool.Metadata = [.. metadataList]; } /// @@ -127,18 +128,18 @@ public async Task DeleteBatchPoolAsync(string poolId, CancellationToken cancella try { logger.LogInformation( - $"Deleting pool with the ID/name: {poolId}"); + "Deleting pool with the ID/name: {PoolId}", poolId); var wsmResourceId = await GetWsmResourceIdFromBatchPoolMetadataAsync(poolId, cancellationToken); await terraWsmApiClient.DeleteBatchPoolAsync(Guid.Parse(terraOptions.WorkspaceId), wsmResourceId, cancellationToken); logger.LogInformation( - $"Successfully deleted pool with the ID/name via WSM: {poolId}"); + "Successfully deleted pool with the ID/name via WSM: {PoolId}", poolId); } catch (Exception e) { - logger.LogError(e, $"Error trying to delete pool named {poolId}"); + logger.LogError(e, "Error trying to delete pool named {PoolId}", poolId); throw; } @@ -146,37 +147,20 @@ public async Task DeleteBatchPoolAsync(string poolId, CancellationToken cancella private async Task GetWsmResourceIdFromBatchPoolMetadataAsync(string poolId, CancellationToken cancellationToken) { - var batchClient = CreateBatchClientFromOptions(); - - var pool = await batchClient.PoolOperations.GetPoolAsync(poolId, cancellationToken: cancellationToken); + var metadataItem = await poolMetadataReader.GetMetadataValueAsync(poolId, TerraResourceIdMetadataKey, cancellationToken); - if (pool is null) - { - throw new InvalidOperationException($"The Batch pool was not found. Pool ID: {poolId}"); - } - - var metadataItem = pool.Metadata.SingleOrDefault(m => m.Name.Equals(TerraResourceIdMetadataKey)); - - if (string.IsNullOrEmpty(metadataItem?.Value)) + if (string.IsNullOrEmpty(metadataItem)) { throw new InvalidOperationException("The WSM resource ID was not found in the pool's metadata."); } - var wsmResourceId = Guid.Parse(metadataItem.Value); + var wsmResourceId = Guid.Parse(metadataItem); return wsmResourceId; } - private BatchClient CreateBatchClientFromOptions() - => BatchClient.Open(new BatchSharedKeyCredentials(batchAccountOptions.BaseUrl, - batchAccountOptions.AccountName, batchAccountOptions.AppKey)); - private void ValidateOptions() { ArgumentException.ThrowIfNullOrEmpty(terraOptions.WorkspaceId, nameof(terraOptions.WorkspaceId)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.AccountName, nameof(batchAccountOptions.AccountName)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.AppKey, nameof(batchAccountOptions.AppKey)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.BaseUrl, nameof(batchAccountOptions.BaseUrl)); - ArgumentException.ThrowIfNullOrEmpty(batchAccountOptions.ResourceGroup, nameof(batchAccountOptions.ResourceGroup)); } } } diff --git a/src/TesApi.Web/Management/BatchAccountUtilization.cs b/src/TesApi.Web/Management/BatchAccountUtilization.cs index 2ac444196..2e8a716b9 100644 --- a/src/TesApi.Web/Management/BatchAccountUtilization.cs +++ b/src/TesApi.Web/Management/BatchAccountUtilization.cs @@ -8,12 +8,8 @@ namespace TesApi.Web.Management /// /// Active job counts /// Active pool count - /// Total cores in use - /// Number of dedicated cores in requested Vm family public record BatchAccountUtilization( int ActiveJobsCount, - int ActivePoolsCount, - int TotalCoresInUse, - int DedicatedCoresInUseInRequestedVmFamily); + int ActivePoolsCount); } diff --git a/src/TesApi.Web/Management/BatchQuotaVerifier.cs b/src/TesApi.Web/Management/BatchQuotaVerifier.cs index 604c622eb..d321f5c1d 100644 --- a/src/TesApi.Web/Management/BatchQuotaVerifier.cs +++ b/src/TesApi.Web/Management/BatchQuotaVerifier.cs @@ -89,8 +89,6 @@ public async Task CheckBatchAccountQuotasAsync(VirtualMachineInformation virtual } var isDedicatedAndPerVmFamilyCoreQuotaEnforced = isDedicated && batchVmFamilyBatchQuotas.DedicatedCoreQuotaPerVmFamilyEnforced; - var batchUtilization = await GetBatchAccountUtilizationAsync(virtualMachineInformation, cancellationToken); - if (workflowCoresRequirement > batchVmFamilyBatchQuotas.TotalCoreQuota) { @@ -106,9 +104,11 @@ public async Task CheckBatchAccountQuotasAsync(VirtualMachineInformation virtual if (needPoolOrJobQuotaCheck) { + var batchUtilization = GetBatchAccountUtilization(); + if (batchUtilization.ActiveJobsCount + 1 > batchVmFamilyBatchQuotas.ActiveJobAndJobScheduleQuota) { - throw new AzureBatchQuotaMaxedOutException($"No remaining active jobs quota available. There are {batchUtilization.ActivePoolsCount} active jobs out of {batchVmFamilyBatchQuotas.ActiveJobAndJobScheduleQuota}."); + throw new AzureBatchQuotaMaxedOutException($"No remaining active jobs quota available. There are {batchUtilization.ActiveJobsCount} active jobs out of {batchVmFamilyBatchQuotas.ActiveJobAndJobScheduleQuota}."); } if (batchUtilization.ActivePoolsCount + 1 > batchVmFamilyBatchQuotas.PoolQuota) @@ -122,28 +122,12 @@ public async Task CheckBatchAccountQuotasAsync(VirtualMachineInformation virtual public IBatchQuotaProvider GetBatchQuotaProvider() => batchQuotaProvider; - private async Task GetBatchAccountUtilizationAsync(VirtualMachineInformation vmInfo, CancellationToken cancellationToken) + private BatchAccountUtilization GetBatchAccountUtilization() { - var isDedicated = !vmInfo.LowPriority; + // TODO: make these async var activeJobsCount = azureProxy.GetBatchActiveJobCount(); var activePoolsCount = azureProxy.GetBatchActivePoolCount(); - var activeNodeCountByVmSize = azureProxy.GetBatchActiveNodeCountByVmSize().ToList(); - var virtualMachineInfoList = await batchSkuInformationProvider.GetVmSizesAndPricesAsync(batchAccountInformation.Region, cancellationToken); - - var totalCoresInUse = activeNodeCountByVmSize - .Sum(x => - virtualMachineInfoList - .FirstOrDefault(vm => vm.VmSize.Equals(x.VirtualMachineSize, StringComparison.OrdinalIgnoreCase))? - .VCpusAvailable * (isDedicated ? x.DedicatedNodeCount : x.LowPriorityNodeCount)) ?? 0; - - var vmSizesInRequestedFamily = virtualMachineInfoList.Where(vm => String.Equals(vm.VmFamily, vmInfo.VmFamily, StringComparison.OrdinalIgnoreCase)).Select(vm => vm.VmSize).ToList(); - - var activeNodeCountByVmSizeInRequestedFamily = activeNodeCountByVmSize.Where(x => vmSizesInRequestedFamily.Contains(x.VirtualMachineSize, StringComparer.OrdinalIgnoreCase)); - - var dedicatedCoresInUseInRequestedVmFamily = activeNodeCountByVmSizeInRequestedFamily - .Sum(x => virtualMachineInfoList.FirstOrDefault(vm => vm.VmSize.Equals(x.VirtualMachineSize, StringComparison.OrdinalIgnoreCase))?.VCpusAvailable * x.DedicatedNodeCount) ?? 0; - - return new BatchAccountUtilization(activeJobsCount, activePoolsCount, totalCoresInUse, dedicatedCoresInUseInRequestedVmFamily); + return new(activeJobsCount, activePoolsCount); } } diff --git a/src/TesApi.Web/Management/Configuration/BatchAccountOptions.cs b/src/TesApi.Web/Management/Configuration/BatchAccountOptions.cs index 2aaa6d9a1..fd00d669d 100644 --- a/src/TesApi.Web/Management/Configuration/BatchAccountOptions.cs +++ b/src/TesApi.Web/Management/Configuration/BatchAccountOptions.cs @@ -13,30 +13,36 @@ public class BatchAccountOptions /// public const string SectionName = "BatchAccount"; + /// /// Account name. /// public string AccountName { get; set; } + /// /// Base URL. /// Required if AppKey is provided. /// public string BaseUrl { get; set; } + /// /// AppKey /// If not set ARM authentication is used. /// public string AppKey { get; set; } + /// /// Arm region where the batch account is located. /// Required if AppKey is provided. /// public string Region { get; set; } + /// /// Subscription Id of the batch account. /// Required if AppKey is provided. /// public string SubscriptionId { get; set; } + /// /// ResourceApiResponse group of the batch account. /// Required if AppKey is provided. diff --git a/src/TesApi.Web/Program.cs b/src/TesApi.Web/Program.cs index b07ee418e..dc27673e6 100644 --- a/src/TesApi.Web/Program.cs +++ b/src/TesApi.Web/Program.cs @@ -3,7 +3,10 @@ using System; using System.Reflection; +using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager; +using CommonUtilities; using CommonUtilities.AzureCloud; using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; @@ -63,14 +66,14 @@ public static IWebHostBuilder CreateWebHostBuilder(string[] args) var config = configBuilder.Build(); Startup.AzureCloudConfig = GetAzureCloudConfig(config); StorageUrlUtils.BlobEndpointHostNameSuffix = $".blob.{Startup.AzureCloudConfig.Suffixes.StorageSuffix}"; - applicationInsightsOptions = GetApplicationInsightsConnectionString(config); + applicationInsightsOptions = GetApplicationInsightsConnectionString(config, Startup.AzureCloudConfig.ArmEnvironment.Value, new AzureServicesConnectionStringCredential(new(config, Startup.AzureCloudConfig))); if (!string.IsNullOrEmpty(applicationInsightsOptions?.ConnectionString)) { configBuilder.AddApplicationInsightsSettings(applicationInsightsOptions.ConnectionString, developerMode: context.HostingEnvironment.IsDevelopment() ? true : null); } - static ApplicationInsightsOptions GetApplicationInsightsConnectionString(IConfiguration configuration) + static ApplicationInsightsOptions GetApplicationInsightsConnectionString(IConfiguration configuration, ArmEnvironment armEnvironment, Azure.Core.TokenCredential credential) { var applicationInsightsOptions = configuration.GetSection(Options.ApplicationInsightsOptions.SectionName).Get(); var applicationInsightsAccountName = applicationInsightsOptions?.AccountName; @@ -84,7 +87,7 @@ static ApplicationInsightsOptions GetApplicationInsightsConnectionString(IConfig if (string.IsNullOrWhiteSpace(applicationInsightsConnectionString)) { - applicationInsightsConnectionString = ArmResourceInformationFinder.GetAppInsightsConnectionStringAsync(applicationInsightsAccountName, Startup.AzureCloudConfig, System.Threading.CancellationToken.None).Result; + applicationInsightsConnectionString = ArmResourceInformationFinder.GetAppInsightsConnectionStringFromAccountNameAsync(applicationInsightsAccountName, credential, armEnvironment, CancellationToken.None).Result; } if (!string.IsNullOrWhiteSpace(applicationInsightsConnectionString)) @@ -116,8 +119,11 @@ static ApplicationInsightsOptions GetApplicationInsightsConnectionString(IConfig logging.AddConsole(); } - // Optional: Apply filters to configure LogLevel Trace or above is sent to - // ApplicationInsights for all categories. + // Optional: Apply filters to configure LogLevel + // Trace or above is sent to ApplicationInsights for all categories. + + // Additional filtering For category starting in "System", + // only Warning or above will be sent to Application Insights. logging.AddFilter("System", LogLevel.Warning); // Additional filtering For category starting in "Microsoft", @@ -146,7 +152,7 @@ static AzureCloudConfig GetAzureCloudConfig(IConfiguration configuration) var tesOptions = new GeneralOptions(); configuration.Bind(GeneralOptions.SectionName, tesOptions); Console.WriteLine($"tesOptions.AzureCloudName: {tesOptions.AzureCloudName}"); - return AzureCloudConfig.CreateAsync(tesOptions.AzureCloudName, tesOptions.AzureCloudMetadataUrlApiVersion).Result; + return AzureCloudConfig.FromKnownCloudNameAsync(cloudName: tesOptions.AzureCloudName, azureCloudMetadataUrlApiVersion: tesOptions.AzureCloudMetadataUrlApiVersion).Result; } } } diff --git a/src/TesApi.Web/Properties/launchSettings.json b/src/TesApi.Web/Properties/launchSettings.json index 498077eb6..bc34d2b2a 100644 --- a/src/TesApi.Web/Properties/launchSettings.json +++ b/src/TesApi.Web/Properties/launchSettings.json @@ -14,7 +14,8 @@ "launchBrowser": true, "launchUrl": "", "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" + "ASPNETCORE_ENVIRONMENT": "Development", + "AzureServicesAuthConnectionString": "RunAs=Developer;DeveloperTool=AzureCLI" } }, "WebApplication1": { @@ -23,7 +24,8 @@ "launchUrl": "", "applicationUrl": "https://localhost:5001;http://localhost:5000", "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" + "ASPNETCORE_ENVIRONMENT": "Development", + "AzureServicesAuthConnectionString": "RunAs=Developer;DeveloperTool=AzureCLI" } } } diff --git a/src/TesApi.Web/Startup.cs b/src/TesApi.Web/Startup.cs index 7ceb51138..901a7e449 100644 --- a/src/TesApi.Web/Startup.cs +++ b/src/TesApi.Web/Startup.cs @@ -8,7 +8,7 @@ using System.Reflection; using System.Threading; using Azure.Core; -using Azure.Identity; +using Azure.ResourceManager; using CommonUtilities; using CommonUtilities.AzureCloud; using CommonUtilities.Options; @@ -69,6 +69,12 @@ public void ConfigureServices(IServiceCollection services) services .AddSingleton(AzureCloudConfig) .AddSingleton(AzureCloudConfig.AzureEnvironmentConfig) + .AddSingleton(s => + { + var options = ActivatorUtilities.CreateInstance(s); + options.AuthorityHost = AzureCloudConfig.AuthorityHost; + return options; + }) .AddLogging() .AddApplicationInsightsTelemetry(configuration) .Configure(configuration.GetSection(GeneralOptions.SectionName)) @@ -91,7 +97,7 @@ public void ConfigureServices(IServiceCollection services) .AddTransient() .AddSingleton() .AddSingleton(CreateTerraApiClient) - .AddSingleton(CreateBatchPoolManagerFromConfiguration) + .AddSingleton(sp => ActivatorUtilities.CreateInstance(sp, CreateBatchPoolManagerFromConfiguration(sp))) .AddControllers(options => options.Filters.Add()) .AddNewtonsoftJson(opts => @@ -117,13 +123,10 @@ public void ConfigureServices(IServiceCollection services) .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton(s => - { - return new DefaultAzureCredential( - new DefaultAzureCredentialOptions { AuthorityHost = new Uri(AzureCloudConfig.Authentication.LoginEndpointUrl) }); - }) + .AddSingleton() .AddSingleton() .AddSingleton() + .AddSingleton() .AddSingleton(c => { @@ -407,7 +410,7 @@ BatchAccountResourceInformation CreateBatchAccountResourceInformation(IServicePr if (string.IsNullOrWhiteSpace(options.Value.AppKey)) { //we are assuming Arm with MI/RBAC if no key is provided. Try to get info from the batch account. - var task = ArmResourceInformationFinder.TryGetResourceInformationFromAccountNameAsync(options.Value.AccountName, AzureCloudConfig, System.Threading.CancellationToken.None); + var task = ArmResourceInformationFinder.TryGetBatchAccountInformationFromAccountNameAsync(options.Value.AccountName, services.GetRequiredService(), AzureCloudConfig.ArmEnvironment.Value, CancellationToken.None); task.Wait(); if (task.Result is null) diff --git a/src/TesApi.Web/Storage/DefaultStorageAccessProvider.cs b/src/TesApi.Web/Storage/DefaultStorageAccessProvider.cs index ab9f5bd67..08f9e0d7a 100644 --- a/src/TesApi.Web/Storage/DefaultStorageAccessProvider.cs +++ b/src/TesApi.Web/Storage/DefaultStorageAccessProvider.cs @@ -8,12 +8,10 @@ using System.Threading.Tasks; using System.Web; using Azure.Storage.Blobs; +using Azure.Storage.Sas; using CommonUtilities; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Microsoft.WindowsAzure.Storage; -using Microsoft.WindowsAzure.Storage.Auth; -using Microsoft.WindowsAzure.Storage.Blob; using Tes.Extensions; using Tes.Models; using TesApi.Web.Options; @@ -45,7 +43,7 @@ public DefaultStorageAccessProvider(ILogger logger this.storageOptions = storageOptions.Value; this.azureEnvironmentConfig = azureEnvironmentConfig; - externalStorageContainers = storageOptions.Value.ExternalStorageContainers?.Split(new[] { ',', ';', '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) + externalStorageContainers = storageOptions.Value.ExternalStorageContainers?.Split([',', ';', '\r', '\n'], StringSplitOptions.RemoveEmptyEntries) .Select(uri => { if (StorageAccountUrlSegments.TryCreate(uri, out var s)) @@ -141,27 +139,37 @@ private async Task AddSasTokenAsync(StorageAccountUrl try { - var accountKey = await AzureProxy.GetStorageAccountKeyAsync(storageAccountInfo, cancellationToken); var resultPathSegments = new StorageAccountUrlSegments(storageAccountInfo.BlobEndpoint, pathSegments.ContainerName, pathSegments.BlobName); - var policy = new SharedAccessBlobPolicy { SharedAccessExpiryTime = DateTimeOffset.UtcNow.Add(sasTokenDuration ?? SasTokenDuration) }; + var sharedAccessExpiryTime = DateTimeOffset.UtcNow.Add(sasTokenDuration ?? SasTokenDuration); + BlobSasBuilder sasBuilder; if (pathSegments.IsContainer || getContainerSas) { - policy.Permissions = SharedAccessBlobPermissions.Add | SharedAccessBlobPermissions.Create | SharedAccessBlobPermissions.List | SharedAccessBlobPermissions.Read | SharedAccessBlobPermissions.Write; - var containerUri = new StorageAccountUrlSegments(storageAccountInfo.BlobEndpoint, pathSegments.ContainerName).ToUri(); - resultPathSegments.SasToken = new CloudBlobContainer(containerUri, new StorageCredentials(storageAccountInfo.Name, accountKey)).GetSharedAccessSignature(policy, null, SharedAccessProtocol.HttpsOnly, null); + sasBuilder = new(BlobContainerSasPermissions.Add | BlobContainerSasPermissions.Create | BlobContainerSasPermissions.List | BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write, sharedAccessExpiryTime) + { + BlobContainerName = pathSegments.ContainerName, + Resource = "b", + }; } else { - policy.Permissions = SharedAccessBlobPermissions.Read; - resultPathSegments.SasToken = new CloudBlob(resultPathSegments.ToUri(), new StorageCredentials(storageAccountInfo.Name, accountKey)).GetSharedAccessSignature(policy, null, null, SharedAccessProtocol.HttpsOnly, null); + sasBuilder = new(BlobContainerSasPermissions.Read, sharedAccessExpiryTime) + { + BlobContainerName = pathSegments.ContainerName, + BlobName = pathSegments.BlobName, + Resource = "c" + }; } + sasBuilder.Protocol = SasProtocol.Https; + var accountCredential = new Azure.Storage.StorageSharedKeyCredential(storageAccountInfo.Name, await AzureProxy.GetStorageAccountKeyAsync(storageAccountInfo, cancellationToken)); + resultPathSegments.SasToken = sasBuilder.ToSasQueryParameters(accountCredential).ToString(); + return resultPathSegments; } catch (Exception ex) { - Logger.LogError(ex, $"Could not get the key of storage account '{pathSegments.AccountName}'. Make sure that the TES app service has Contributor access to it."); + Logger.LogError(ex, "Could not get the key of storage account '{StorageAccount}'. Make sure that the TES app service has Contributor access to it.", pathSegments.AccountName); return null; } } diff --git a/src/TesApi.Web/Storage/StorageAccessProvider.cs b/src/TesApi.Web/Storage/StorageAccessProvider.cs index 30b3b5bbf..35377cc31 100644 --- a/src/TesApi.Web/Storage/StorageAccessProvider.cs +++ b/src/TesApi.Web/Storage/StorageAccessProvider.cs @@ -95,7 +95,17 @@ public async Task UploadBlobAsync(Uri blobAbsoluteUrl, string content, /// public async Task> GetBlobUrlsAsync(Uri blobVirtualDirectory, CancellationToken cancellationToken) { - return (await AzureProxy.ListBlobsAsync(blobVirtualDirectory, cancellationToken)).Select(b => b.Uri).ToList(); + Azure.Storage.Blobs.BlobUriBuilder blobBuilder = new(blobVirtualDirectory) { Sas = null }; + return (await AzureProxy.ListBlobsAsync(blobVirtualDirectory, cancellationToken)).Select(GetBlobUri).ToList(); + + Uri GetBlobUri(Azure.Storage.Blobs.Models.BlobItem blob) + { + // This implementation reuses the BlobUriBuilder in the parent method, so GetBlobUri cannot be called in parallel with the same instance of BlobUriBuilder. + // It is safe for concurrent instances of GetBlobUrlsAsync to run simultaneously, however. + // Refactor if the ListBlobsAsync enumeration is ever parallelized at the stage of calling this converter method. + blobBuilder.BlobName = blob.Name; + return blobBuilder.ToUri(); + } } /// diff --git a/src/TesApi.Web/TesApi.Web.csproj b/src/TesApi.Web/TesApi.Web.csproj index b9a269df2..2eb5dcd96 100644 --- a/src/TesApi.Web/TesApi.Web.csproj +++ b/src/TesApi.Web/TesApi.Web.csproj @@ -48,26 +48,18 @@ - - - - - + - - - - - - + + - + - + diff --git a/src/deploy-tes-on-azure.Tests/KubernetesManagerTests.cs b/src/deploy-tes-on-azure.Tests/KubernetesManagerTests.cs index 23ef74502..cd5a02e30 100644 --- a/src/deploy-tes-on-azure.Tests/KubernetesManagerTests.cs +++ b/src/deploy-tes-on-azure.Tests/KubernetesManagerTests.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using CommonUtilities; using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; namespace TesDeployer.Tests { @@ -13,8 +14,7 @@ public class KubernetesManagerTests [TestMethod] public async Task ValuesTemplateSuccessfullyDeserializesTesdatabaseToYaml() { - var azureConfig = ExpensiveObjectTestUtility.AzureCloudConfig; - var manager = new KubernetesManager(null, null, azureConfig, System.Threading.CancellationToken.None); + var manager = new KubernetesManager(new(), ExpensiveObjectTestUtility.AzureCloudConfig, (_, _, _) => throw new System.NotImplementedException(), System.Threading.CancellationToken.None); var helmValues = await manager.GetHelmValuesAsync(@"./cromwell-on-azure/helm/values-template.yaml"); Assert.IsNotNull(helmValues.TesDatabase); } diff --git a/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj b/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj index 8dec8c8ee..b95604018 100644 --- a/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj +++ b/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj @@ -8,9 +8,10 @@ - - - + + + + diff --git a/src/deploy-tes-on-azure/Deployer.cs b/src/deploy-tes-on-azure/Deployer.cs index 534f8be3f..903cac9d0 100644 --- a/src/deploy-tes-on-azure/Deployer.cs +++ b/src/deploy-tes-on-azure/Deployer.cs @@ -3,50 +3,49 @@ using System; using System.Collections.Generic; +using System.Data; using System.Diagnostics; using System.IdentityModel.Tokens.Jwt; using System.IO; using System.Linq; +using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Security.Cryptography; -using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; +using Azure; using Azure.Core; using Azure.Identity; using Azure.ResourceManager; +using Azure.ResourceManager.ApplicationInsights; +using Azure.ResourceManager.ApplicationInsights.Models; +using Azure.ResourceManager.Authorization; +using Azure.ResourceManager.Batch; +using Azure.ResourceManager.Compute; using Azure.ResourceManager.ContainerService; using Azure.ResourceManager.ContainerService.Models; +using Azure.ResourceManager.KeyVault; +using Azure.ResourceManager.KeyVault.Models; using Azure.ResourceManager.ManagedServiceIdentities; using Azure.ResourceManager.Network; using Azure.ResourceManager.Network.Models; +using Azure.ResourceManager.OperationalInsights; +using Azure.ResourceManager.PostgreSql.FlexibleServers; +using Azure.ResourceManager.PostgreSql.FlexibleServers.Models; +using Azure.ResourceManager.PrivateDns; +using Azure.ResourceManager.ResourceGraph; using Azure.ResourceManager.Resources; +using Azure.ResourceManager.Resources.Models; +using Azure.ResourceManager.Storage; using Azure.Security.KeyVault.Secrets; -using Azure.Storage; using Azure.Storage.Blobs; +using Azure.Storage.Blobs.Specialized; using CommonUtilities; using CommonUtilities.AzureCloud; using k8s; -using Microsoft.Azure.Management.Batch; -using Microsoft.Azure.Management.Batch.Models; -using Microsoft.Azure.Management.Compute.Fluent; -using Microsoft.Azure.Management.ContainerRegistry.Fluent; -using Microsoft.Azure.Management.Fluent; -using Microsoft.Azure.Management.Graph.RBAC.Fluent; -using Microsoft.Azure.Management.KeyVault; -using Microsoft.Azure.Management.KeyVault.Fluent; -using Microsoft.Azure.Management.KeyVault.Models; -using Microsoft.Azure.Management.Msi.Fluent; -using Microsoft.Azure.Management.Network.Fluent; -using Microsoft.Azure.Management.PostgreSQL; -using Microsoft.Azure.Management.PrivateDns.Fluent; -using Microsoft.Azure.Management.ResourceGraph; -using Microsoft.Azure.Management.ResourceManager.Fluent; -using Microsoft.Azure.Management.ResourceManager.Fluent.Authentication; -using Microsoft.Azure.Management.ResourceManager.Fluent.Core; -using Microsoft.Azure.Management.Storage.Fluent; -using Microsoft.Azure.Services.AppAuthentication; +using Microsoft.EntityFrameworkCore; +using Microsoft.Graph; using Microsoft.Rest; using Newtonsoft.Json; using Polly; @@ -54,14 +53,8 @@ using Tes.Extensions; using Tes.Models; using Tes.SDK; -using static Microsoft.Azure.Management.PostgreSQL.FlexibleServers.DatabasesOperationsExtensions; -using static Microsoft.Azure.Management.PostgreSQL.FlexibleServers.ServersOperationsExtensions; -using static Microsoft.Azure.Management.PostgreSQL.ServersOperationsExtensions; -using static Microsoft.Azure.Management.ResourceManager.Fluent.Core.RestClient; -using FlexibleServer = Microsoft.Azure.Management.PostgreSQL.FlexibleServers; -using FlexibleServerModel = Microsoft.Azure.Management.PostgreSQL.FlexibleServers.Models; -using IResource = Microsoft.Azure.Management.ResourceManager.Fluent.Core.IResource; -using KeyVaultManagementClient = Microsoft.Azure.Management.KeyVault.KeyVaultManagementClient; +using Batch = Azure.ResourceManager.Batch.Models; +using Storage = Azure.ResourceManager.Storage.Models; namespace TesDeployer { @@ -76,13 +69,19 @@ public class Deployer(Configuration configuration) .Handle(azureException => (int)System.Net.HttpStatusCode.Conflict == azureException.Status && "OperationNotAllowed".Equals(azureException.ErrorCode)) - .WaitAndRetryAsync(30, retryAttempt => System.TimeSpan.FromSeconds(10)); + .WaitAndRetryAsync(30, retryAttempt => TimeSpan.FromSeconds(10)); private static readonly AsyncRetryPolicy generalRetryPolicy = Policy .Handle() - .WaitAndRetryAsync(3, retryAttempt => System.TimeSpan.FromSeconds(1)); + .WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(1)); - private static readonly System.TimeSpan longRetryWaitTime = System.TimeSpan.FromSeconds(15); + private static readonly TimeSpan longRetryWaitTime = TimeSpan.FromSeconds(15); + + /// + /// Grants full access to manage all resources, but does not allow you to assign roles in Azure RBAC, manage assignments in Azure Blueprints, or share image galleries. + /// + /// https://learn.microsoft.com/azure/role-based-access-control/built-in-roles#general 'Contributor' in table. + private static readonly ResourceIdentifier All_Role_Contributor = AuthorizationRoleDefinitionResource.CreateResourceIdentifier(string.Empty, new("b24988ac-6180-42a0-ab88-20f7382dd24c")); public const string ConfigurationContainerName = "configuration"; public const string TesInternalContainerName = "tes-internal"; @@ -111,22 +110,56 @@ public class Deployer(Configuration configuration) private readonly Dictionary> requiredResourceProviderFeatures = new() { - { "Microsoft.Compute", new() { "EncryptionAtHost" } } + { "Microsoft.Compute", new() { "EncryptionAtHost" } }, }; -#pragma warning disable CA1859 // Use concrete types when possible for improved performance - private ITokenProvider tokenProvider; -#pragma warning restore CA1859 // Use concrete types when possible for improved performance - private TokenCredentials tokenCredentials; - private IAzure azureSubscriptionClient { get; set; } - private Microsoft.Azure.Management.Fluent.Azure.IAuthenticated azureClient { get; set; } - private IResourceManager resourceManagerClient { get; set; } + + [System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1859:Use concrete types when possible for improved performance", Justification = "We are using the base type everywhere.")] + private TokenCredential tokenCredential { get; set; } + private SubscriptionResource armSubscription { get; set; } private ArmClient armClient { get; set; } - private AzureCredentials azureCredentials { get; set; } - private FlexibleServer.IPostgreSQLManagementClient postgreSqlFlexManagementClient { get; set; } - private IEnumerable subscriptionIds { get; set; } + private ResourceGroupResource resourceGroup { get; set; } + private CloudEnvironment cloudEnvironment { get; set; } + private IEnumerable subscriptionIds { get; set; } private bool isResourceGroupCreated { get; set; } private KubernetesManager kubernetesManager { get; set; } - internal static AzureCloudConfig azureCloudConfig { get; set; } + internal static AzureCloudConfig azureCloudConfig { get; private set; } + private static readonly System.Collections.Concurrent.ConcurrentDictionary storageKeys = []; + + private static async Task EnsureResourceDataAsync(T resource, Predicate HasData, Func>>> GetAsync, CancellationToken cancellationToken, Action OnAcquisition = null) where T : ArmResource + { + return HasData(resource) + ? resource + : await FetchResourceDataAsync(GetAsync(resource), cancellationToken, OnAcquisition); + } + + private static async Task FetchResourceDataAsync(Func>> GetAsync, CancellationToken cancellationToken, Action OnAcquisition = null) where T : ArmResource + { + ArgumentNullException.ThrowIfNull(GetAsync); + + var result = await GetAsync(cancellationToken); + OnAcquisition?.Invoke(result); + return result; + } + + private Azure.Storage.StorageSharedKeyCredential GetStorageSharedKeyCredential(StorageAccountData storageAccount) + { + return storageKeys.GetOrAdd(storageAccount.Id, id => + { + var key = armClient + .GetStorageAccountResource(storageAccount.Id) + .GetKeysAsync(cancellationToken: cts.Token) + .FirstOrDefaultAsync(cts.Token) + .AsTask().GetAwaiter().GetResult(); + return new(storageAccount.Name, key.Value); + }); + } + + private BlobClient GetBlobClient(StorageAccountData storageAccount, string containerName, string blobName) + { + return new(new BlobUriBuilder(storageAccount.PrimaryEndpoints.BlobUri) { BlobContainerName = containerName, BlobName = blobName }.ToUri(), + GetStorageSharedKeyCredential(storageAccount), + new() { Audience = storageAccount.PrimaryEndpoints.BlobUri.AbsoluteUri }); + } public async Task DeployAsync() { @@ -138,7 +171,8 @@ public async Task DeployAsync() await Execute($"Getting cloud configuration for {configuration.AzureCloudName}...", async () => { - azureCloudConfig = await AzureCloudConfig.CreateAsync(configuration.AzureCloudName); + azureCloudConfig = await AzureCloudConfig.FromKnownCloudNameAsync(cloudName: configuration.AzureCloudName, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new())); + cloudEnvironment = new(azureCloudConfig.ArmEnvironment.Value, azureCloudConfig.AuthorityHost); }); await Execute("Validating command line arguments...", () => @@ -151,43 +185,39 @@ await Execute("Validating command line arguments...", () => await Execute("Connecting to Azure Services...", async () => { - tokenProvider = new RefreshableAzureServiceTokenProvider(azureCloudConfig.ResourceManagerUrl, null, azureCloudConfig.Authentication.LoginEndpointUrl); - tokenCredentials = new(tokenProvider); - azureCredentials = new(tokenCredentials, null, null, azureCloudConfig.AzureEnvironment); - azureClient = GetAzureClient(azureCredentials); - armClient = new ArmClient(new AzureCliCredential(), null, new ArmClientOptions { Environment = azureCloudConfig.ArmEnvironment }); - azureSubscriptionClient = azureClient.WithSubscription(configuration.SubscriptionId); - subscriptionIds = await (await azureClient.Subscriptions.ListAsync(cancellationToken: cts.Token)).ToAsyncEnumerable().Select(s => s.SubscriptionId).ToListAsync(cts.Token); - resourceManagerClient = GetResourceManagerClient(azureCredentials); - postgreSqlFlexManagementClient = new FlexibleServer.PostgreSQLManagementClient(azureCredentials) { SubscriptionId = configuration.SubscriptionId, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl), LongRunningOperationRetryTimeout = 1200 }; + tokenCredential = new AzureCliCredential(new() { AuthorityHost = cloudEnvironment.AzureAuthorityHost }); + armClient = new ArmClient(tokenCredential, configuration.SubscriptionId, new() { Environment = cloudEnvironment.ArmEnvironment }); + armSubscription = armClient.GetSubscriptionResource(SubscriptionResource.CreateResourceIdentifier(configuration.SubscriptionId)); + subscriptionIds = await armClient.GetSubscriptions().GetAllAsync(cts.Token).ToListAsync(cts.Token); }); await ValidateSubscriptionAndResourceGroupAsync(configuration); - kubernetesManager = new(configuration, azureCredentials, azureCloudConfig, cts.Token); - IResourceGroup resourceGroup = null; + kubernetesManager = new(configuration, azureCloudConfig, GetBlobClient, cts.Token); + ContainerServiceManagedClusterResource aksCluster = null; - BatchAccount batchAccount = null; - IGenericResource logAnalyticsWorkspace = null; - IGenericResource appInsights = null; - FlexibleServerModel.Server postgreSqlFlexServer = null; - IStorageAccount storageAccount = null; - var keyVaultUri = string.Empty; - IIdentity managedIdentity = null; - IPrivateDnsZone postgreSqlDnsZone = null; + BatchAccountResource batchAccount = null; + OperationalInsightsWorkspaceResource logAnalyticsWorkspace = null; + ApplicationInsightsComponentResource appInsights = null; + PostgreSqlFlexibleServerResource postgreSqlFlexServer = null; + StorageAccountResource storageAccount = null; + StorageAccountData storageAccountData = null; + Uri keyVaultUri = null; + UserAssignedIdentityResource managedIdentity = null; + PrivateDnsZoneResource postgreSqlDnsZone = null; var targetVersion = Utility.DelimitedTextToDictionary(Utility.GetFileContent("scripts", "env-00-tes-version.txt")).GetValueOrDefault("TesOnAzureVersion"); if (configuration.Update) { - resourceGroup = await azureSubscriptionClient.ResourceGroups.GetByNameAsync(configuration.ResourceGroupName, cts.Token); - configuration.RegionName = resourceGroup.RegionName; + resourceGroup = (await armSubscription.GetResourceGroupAsync(configuration.ResourceGroupName, cts.Token)).Value; + configuration.RegionName = resourceGroup.Id.Location ?? + ((await EnsureResourceDataAsync(resourceGroup, g => g.HasData, g => g.GetAsync, cts.Token, g => resourceGroup = g)).Data.Location.Name); - ConsoleEx.WriteLine($"Upgrading TES on Azure instance in resource group '{resourceGroup.Name}' to version {targetVersion}..."); + ConsoleEx.WriteLine($"Upgrading TES on Azure instance in resource group '{resourceGroup.Id.Name}' to version {targetVersion}..."); if (string.IsNullOrEmpty(configuration.StorageAccountName)) { - var storageAccounts = await (await azureSubscriptionClient.StorageAccounts.ListByResourceGroupAsync(configuration.ResourceGroupName, cancellationToken: cts.Token)) - .ToAsyncEnumerable().ToListAsync(cts.Token); + var storageAccounts = await resourceGroup.GetStorageAccounts().ToListAsync(cts.Token); storageAccount = storageAccounts.Count switch { @@ -202,10 +232,11 @@ await Execute("Connecting to Azure Services...", async () => ?? throw new ValidationException($"Storage account {configuration.StorageAccountName} does not exist in region {configuration.RegionName} or is not accessible to the current user.", displayExample: false); } + storageAccountData = (await FetchResourceDataAsync(ct => storageAccount.GetAsync(cancellationToken: ct), cts.Token, account => storageAccount = account)).Data; + if (string.IsNullOrWhiteSpace(configuration.AksClusterName)) { - var client = armClient.GetResourceGroupResource(new(resourceGroup.Id)); - var aksClusters = await client.GetContainerServiceManagedClusters().GetAllAsync(cts.Token).ToListAsync(cts.Token); + var aksClusters = await resourceGroup.GetContainerServiceManagedClusters().GetAllAsync(cts.Token).ToListAsync(cts.Token); aksCluster = aksClusters.Count switch { @@ -222,11 +253,11 @@ await Execute("Connecting to Azure Services...", async () => ?? throw new ValidationException($"AKS cluster {configuration.AksClusterName} does not exist in region {configuration.RegionName} or is not accessible to the current user.", displayExample: false); } - var aksValues = await kubernetesManager.GetAKSSettingsAsync(storageAccount); + var aksValues = await kubernetesManager.GetAKSSettingsAsync(storageAccountData); - if (!aksValues.Any()) + if (0 == aksValues.Count) { - throw new ValidationException($"Could not retrieve account names from stored configuration in {storageAccount.Name}.", displayExample: false); + throw new ValidationException($"Could not retrieve account names from stored configuration in {storageAccountData.Name}.", displayExample: false); } if (aksValues.TryGetValue("EnableIngress", out var enableIngress) && aksValues.TryGetValue("TesHostname", out var tesHostname)) @@ -270,7 +301,7 @@ await Execute("Connecting to Azure Services...", async () => if (!aksValues.TryGetValue("BatchAccountName", out var batchAccountName)) { - throw new ValidationException($"Could not retrieve the Batch account name from stored configuration in {storageAccount.Name}.", displayExample: false); + throw new ValidationException($"Could not retrieve the Batch account name from stored configuration in {storageAccount.Id.Name}.", displayExample: false); } batchAccount = await GetExistingBatchAccountAsync(batchAccountName) @@ -280,7 +311,7 @@ await Execute("Connecting to Azure Services...", async () => if (!aksValues.TryGetValue("PostgreSqlServerName", out var postgreSqlServerName)) { - throw new ValidationException($"Could not retrieve the PostgreSqlServer account name from stored configuration in {storageAccount.Name}.", displayExample: false); + throw new ValidationException($"Could not retrieve the PostgreSqlServer account name from stored configuration in {storageAccount.Id.Name}.", displayExample: false); } configuration.PostgreSqlServerName = postgreSqlServerName; @@ -292,8 +323,7 @@ await Execute("Connecting to Azure Services...", async () => if (aksValues.TryGetValue("KeyVaultName", out var keyVaultName)) { - var keyVault = await GetKeyVaultAsync(keyVaultName); - keyVaultUri = keyVault.Properties.VaultUri; + keyVaultUri = (await EnsureResourceDataAsync(await GetKeyVaultAsync(keyVaultName), vault => vault.HasData, vault => vault.GetAsync, cts.Token)).Data.Properties.VaultUri; } if (!aksValues.TryGetValue("ManagedIdentityClientId", out var managedIdentityClientId)) @@ -301,8 +331,10 @@ await Execute("Connecting to Azure Services...", async () => throw new ValidationException($"Could not retrieve ManagedIdentityClientId.", displayExample: false); } - managedIdentity = await (await azureSubscriptionClient.Identities.ListByResourceGroupAsync(configuration.ResourceGroupName, cancellationToken: cts.Token)) - .ToAsyncEnumerable().FirstOrDefaultAsync(id => id.ClientId == managedIdentityClientId, cts.Token) + var clientId = Guid.Parse(managedIdentityClientId); + managedIdentity = await resourceGroup.GetUserAssignedIdentities() + .SelectAwaitWithCancellation(async (id, ct) => await FetchResourceDataAsync(id.GetAsync, ct)) + .FirstOrDefaultAsync(id => id.Data.ClientId == clientId, cts.Token) ?? throw new ValidationException($"Managed Identity {managedIdentityClientId} does not exist in region {configuration.RegionName} or is not accessible to the current user.", displayExample: false); // Override any configuration that is used by the update. @@ -319,7 +351,7 @@ await Execute("Connecting to Azure Services...", async () => } } - var settings = ConfigureSettings(managedIdentity.ClientId, aksValues, installedVersion); + var settings = ConfigureSettings(managedIdentity.Data.ClientId?.ToString("D"), aksValues, installedVersion); var waitForRoleAssignmentPropagation = false; if (installedVersion is null || installedVersion < new Version(4, 4)) @@ -329,7 +361,7 @@ await Execute("Connecting to Azure Services...", async () => if (string.IsNullOrWhiteSpace(settings["BatchNodesSubnetId"])) { - settings["BatchNodesSubnetId"] = await UpdateVnetWithBatchSubnet(resourceGroup.Inner.Id); + settings["BatchNodesSubnetId"] = await UpdateVnetWithBatchSubnet(); } } @@ -377,11 +409,11 @@ await Execute("Connecting to Azure Services...", async () => if (waitForRoleAssignmentPropagation) { await Execute("Waiting 5 minutes for role assignment propagation...", - () => Task.Delay(System.TimeSpan.FromMinutes(5), cts.Token)); + () => Task.Delay(TimeSpan.FromMinutes(5), cts.Token)); } - await kubernetesManager.UpgradeValuesYamlAsync(storageAccount, settings); - await PerformHelmDeploymentAsync(resourceGroup); + await kubernetesManager.UpgradeValuesYamlAsync(storageAccountData, settings); + await PerformHelmDeploymentAsync(aksCluster); } if (!configuration.Update) @@ -393,109 +425,113 @@ await Execute("Connecting to Azure Services...", async () => configuration.BatchPrefix = blob.ConvertToBase32().TrimEnd('='); } - ValidateRegionName(configuration.RegionName); - ValidateMainIdentifierPrefix(configuration.MainIdentifierPrefix); - storageAccount = await ValidateAndGetExistingStorageAccountAsync(); - batchAccount = await ValidateAndGetExistingBatchAccountAsync(); - aksCluster = await ValidateAndGetExistingAKSClusterAsync(); - postgreSqlFlexServer = await ValidateAndGetExistingPostgresqlServerAsync(); - var keyVault = await ValidateAndGetExistingKeyVaultAsync(); - - if (aksCluster is null && !configuration.ManualHelmDeployment) + KeyVaultResource keyVault = default; + await Execute("Validating existing Azure resources...", async () => { - //await ValidateVmAsync(); - } + await ValidateRegionNameAsync(configuration.RegionName); + ValidateMainIdentifierPrefix(configuration.MainIdentifierPrefix); + storageAccount = await ValidateAndGetExistingStorageAccountAsync(); + batchAccount = await ValidateAndGetExistingBatchAccountAsync(); + aksCluster = await ValidateAndGetExistingAKSClusterAsync(); + postgreSqlFlexServer = await ValidateAndGetExistingPostgresqlServerAsync(); + var keyVault = await ValidateAndGetExistingKeyVaultAsync(); + + if (aksCluster is null && !configuration.ManualHelmDeployment) + { + //await ValidateVmAsync(); + } - ConsoleEx.WriteLine($"Deploying TES on Azure version {targetVersion}..."); + if (string.IsNullOrWhiteSpace(configuration.PostgreSqlServerNameSuffix)) + { + configuration.PostgreSqlServerNameSuffix = $".{azureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix}"; + } - if (string.IsNullOrWhiteSpace(configuration.PostgreSqlServerNameSuffix)) - { - configuration.PostgreSqlServerNameSuffix = $".{azureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix}"; - } + // Configuration preferences not currently settable by user. + if (string.IsNullOrWhiteSpace(configuration.PostgreSqlServerName)) + { + configuration.PostgreSqlServerName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); + } - // Configuration preferences not currently settable by user. - if (string.IsNullOrWhiteSpace(configuration.PostgreSqlServerName)) - { - configuration.PostgreSqlServerName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); - } + configuration.PostgreSqlAdministratorPassword = PasswordGenerator.GeneratePassword(); + configuration.PostgreSqlTesUserPassword = PasswordGenerator.GeneratePassword(); - configuration.PostgreSqlAdministratorPassword = PasswordGenerator.GeneratePassword(); - configuration.PostgreSqlTesUserPassword = PasswordGenerator.GeneratePassword(); + if (string.IsNullOrWhiteSpace(configuration.BatchAccountName)) + { + configuration.BatchAccountName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}", 15); + } - if (string.IsNullOrWhiteSpace(configuration.BatchAccountName)) - { - configuration.BatchAccountName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}", 15); - } + if (string.IsNullOrWhiteSpace(configuration.StorageAccountName)) + { + configuration.StorageAccountName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}", 24); + } - if (string.IsNullOrWhiteSpace(configuration.StorageAccountName)) - { - configuration.StorageAccountName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}", 24); - } + //if (string.IsNullOrWhiteSpace(configuration.NetworkSecurityGroupName)) + //{ + // configuration.NetworkSecurityGroupName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}", 15); + //} - //if (string.IsNullOrWhiteSpace(configuration.NetworkSecurityGroupName)) - //{ - // configuration.NetworkSecurityGroupName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}", 15); - //} + if (string.IsNullOrWhiteSpace(configuration.ApplicationInsightsAccountName)) + { + configuration.ApplicationInsightsAccountName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); + } - if (string.IsNullOrWhiteSpace(configuration.ApplicationInsightsAccountName)) - { - configuration.ApplicationInsightsAccountName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); - } + if (string.IsNullOrWhiteSpace(configuration.TesPassword)) + { + configuration.TesPassword = PasswordGenerator.GeneratePassword(); + } - if (string.IsNullOrWhiteSpace(configuration.TesPassword)) - { - configuration.TesPassword = PasswordGenerator.GeneratePassword(); - } + if (string.IsNullOrWhiteSpace(configuration.AksClusterName)) + { + configuration.AksClusterName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 25); + } - if (string.IsNullOrWhiteSpace(configuration.AksClusterName)) - { - configuration.AksClusterName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 25); - } + if (string.IsNullOrWhiteSpace(configuration.KeyVaultName)) + { + configuration.KeyVaultName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); + } - if (string.IsNullOrWhiteSpace(configuration.KeyVaultName)) - { - configuration.KeyVaultName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); - } + await RegisterResourceProvidersAsync(); + await RegisterResourceProviderFeaturesAsync(); - await RegisterResourceProvidersAsync(); - await RegisterResourceProviderFeaturesAsync(); + if (batchAccount is null) + { + await ValidateBatchAccountQuotaAsync(); + } + }); - if (batchAccount is null) - { - await ValidateBatchAccountQuotaAsync(); - } + ConsoleEx.WriteLine($"Deploying TES on Azure version {targetVersion}..."); var vnetAndSubnet = await ValidateAndGetExistingVirtualNetworkAsync(); if (string.IsNullOrWhiteSpace(configuration.ResourceGroupName)) { - configuration.ResourceGroupName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); + configuration.ResourceGroupName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); resourceGroup = await CreateResourceGroupAsync(); isResourceGroupCreated = true; } else { - resourceGroup = await azureSubscriptionClient.ResourceGroups.GetByNameAsync(configuration.ResourceGroupName, cts.Token); + resourceGroup = (await armSubscription.GetResourceGroupAsync(configuration.ResourceGroupName, cts.Token)).Value; } // Derive TES ingress URL from resource group name kubernetesManager.SetTesIngressNetworkingConfiguration(configuration.ResourceGroupName); - managedIdentity = await CreateUserManagedIdentityAsync(resourceGroup); + managedIdentity = await EnsureResourceDataAsync(await CreateUserManagedIdentityAsync(), id => id.HasData, id => id.GetAsync, cts.Token); if (vnetAndSubnet is not null) { - ConsoleEx.WriteLine($"Creating VM in existing virtual network {vnetAndSubnet.Value.virtualNetwork.Name} and subnet {vnetAndSubnet.Value.vmSubnet.Name}"); + ConsoleEx.WriteLine($"Creating VM in existing virtual network {vnetAndSubnet.Value.virtualNetwork.Id.Name} and subnet {vnetAndSubnet.Value.vmSubnet.Id.Name}"); } if (storageAccount is not null) { - ConsoleEx.WriteLine($"Using existing Storage Account {storageAccount.Name}"); + ConsoleEx.WriteLine($"Using existing Storage Account {storageAccount.Id.Name}"); } if (batchAccount is not null) { - ConsoleEx.WriteLine($"Using existing Batch Account {batchAccount.Name}"); + ConsoleEx.WriteLine($"Using existing Batch Account {batchAccount.Id.Name}"); } await Task.WhenAll( @@ -504,15 +540,15 @@ await Task.WhenAll( { if (vnetAndSubnet is null) { - configuration.VnetName = SdkContext.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); + configuration.VnetName = Utility.RandomResourceName($"{configuration.MainIdentifierPrefix}-", 15); configuration.PostgreSqlSubnetName = string.IsNullOrEmpty(configuration.PostgreSqlSubnetName) ? configuration.DefaultPostgreSqlSubnetName : configuration.PostgreSqlSubnetName; configuration.BatchSubnetName = string.IsNullOrEmpty(configuration.BatchSubnetName) ? configuration.DefaultBatchSubnetName : configuration.BatchSubnetName; configuration.VmSubnetName = string.IsNullOrEmpty(configuration.VmSubnetName) ? configuration.DefaultVmSubnetName : configuration.VmSubnetName; - vnetAndSubnet = await CreateVnetAndSubnetsAsync(resourceGroup); + vnetAndSubnet = await CreateVnetAndSubnetsAsync(); - if (string.IsNullOrWhiteSpace(configuration.LogAnalyticsArmId)) + if (string.IsNullOrEmpty(configuration.BatchNodesSubnetId)) { - configuration.BatchNodesSubnetId = vnetAndSubnet.Value.batchSubnet.Inner.Id; + configuration.BatchNodesSubnetId = vnetAndSubnet.Value.batchSubnet.Id; } } }), @@ -520,16 +556,17 @@ await Task.WhenAll( { if (string.IsNullOrWhiteSpace(configuration.LogAnalyticsArmId)) { - var workspaceName = SdkContext.RandomResourceName(configuration.MainIdentifierPrefix, 15); + var workspaceName = Utility.RandomResourceName(configuration.MainIdentifierPrefix, 15); logAnalyticsWorkspace = await CreateLogAnalyticsWorkspaceResourceAsync(workspaceName); configuration.LogAnalyticsArmId = logAnalyticsWorkspace.Id; } }), Task.Run(async () => { - storageAccount ??= await CreateStorageAccountAsync(); + storageAccount = await EnsureResourceDataAsync(storageAccount ?? await CreateStorageAccountAsync(), r => r.HasData, r => ct => r.GetAsync(cancellationToken: ct), cts.Token); await CreateDefaultStorageContainersAsync(storageAccount); - await WritePersonalizedFilesToStorageAccountAsync(storageAccount); + storageAccountData = storageAccount.Data; + await WritePersonalizedFilesToStorageAccountAsync(storageAccountData); await AssignVmAsContributorToStorageAccountAsync(managedIdentity, storageAccount); await AssignVmAsDataOwnerToStorageAccountAsync(managedIdentity, storageAccount); await AssignManagedIdOperatorToResourceAsync(managedIdentity, resourceGroup); @@ -541,10 +578,10 @@ await Task.WhenAll( { await Task.Run(async () => { - keyVault ??= await CreateKeyVaultAsync(configuration.KeyVaultName, managedIdentity, vnetAndSubnet.Value.vmSubnet); - keyVaultUri = keyVault.Properties.VaultUri; - var keys = await storageAccount.GetKeysAsync(); - await SetStorageKeySecret(keyVaultUri, StorageAccountKeySecretName, keys[0].Value); + keyVault ??= await CreateKeyVaultAsync(configuration.KeyVaultName, managedIdentity, vnetAndSubnet.Value.virtualNetwork, vnetAndSubnet.Value.vmSubnet); + keyVaultUri = (await EnsureResourceDataAsync(keyVault, r => r.HasData, r => r.GetAsync, cts.Token)).Data.Properties.VaultUri; + var key = await storageAccount.GetKeysAsync(cancellationToken: cts.Token).FirstAsync(cts.Token); + await SetStorageKeySecret(keyVaultUri, StorageAccountKeySecretName, key.Value); }); } @@ -559,7 +596,7 @@ await Task.WhenAll( { if (aksCluster is null && !configuration.ManualHelmDeployment) { - aksCluster = await ProvisionManagedClusterAsync(resourceGroup, managedIdentity, logAnalyticsWorkspace, vnetAndSubnet?.virtualNetwork, vnetAndSubnet?.vmSubnet.Name, configuration.PrivateNetworking.GetValueOrDefault()); + aksCluster = await ProvisionManagedClusterAsync(managedIdentity, logAnalyticsWorkspace, vnetAndSubnet?.vmSubnet.Id, configuration.PrivateNetworking.GetValueOrDefault()); await EnableWorkloadIdentity(aksCluster, managedIdentity, resourceGroup); } }), @@ -570,20 +607,20 @@ await Task.WhenAll( }), Task.Run(async () => { - appInsights = await CreateAppInsightsResourceAsync(configuration.LogAnalyticsArmId); + appInsights = await CreateAppInsightsResourceAsync(new(configuration.LogAnalyticsArmId)); await AssignVmAsContributorToAppInsightsAsync(managedIdentity, appInsights); }), Task.Run(async () => { - postgreSqlFlexServer ??= await CreatePostgreSqlServerAndDatabaseAsync(postgreSqlFlexManagementClient, vnetAndSubnet.Value.postgreSqlSubnet, postgreSqlDnsZone); + postgreSqlFlexServer ??= await CreatePostgreSqlServerAndDatabaseAsync(vnetAndSubnet.Value.postgreSqlSubnet, postgreSqlDnsZone); }) ]); - var clientId = managedIdentity.ClientId; - var settings = ConfigureSettings(clientId); + var clientId = managedIdentity.Data.ClientId; + var settings = ConfigureSettings(clientId?.ToString("D")); - await kubernetesManager.UpdateHelmValuesAsync(storageAccount, keyVaultUri, resourceGroup.Name, settings, managedIdentity); - await PerformHelmDeploymentAsync(resourceGroup, + await kubernetesManager.UpdateHelmValuesAsync(storageAccountData, keyVaultUri, resourceGroup.Id.Name, settings, managedIdentity.Data); + await PerformHelmDeploymentAsync(aksCluster, [ "Run the following postgresql command to setup the database.", $"\tPostgreSQL command: psql postgresql://{configuration.PostgreSqlAdministratorLogin}:{configuration.PostgreSqlAdministratorPassword}@{configuration.PostgreSqlServerName}.{azureCloudConfig.Suffixes.PostgresqlServerEndpointSuffix}/{configuration.PostgreSqlTesDatabaseName} -c \"{GetCreateTesUserString()}\"" @@ -625,10 +662,11 @@ await Execute( ConsoleEx.WriteLine($"TES credentials file written to: {credentialsPath}"); } - var maxPerFamilyQuota = batchAccount.DedicatedCoreQuotaPerVMFamilyEnforced ? batchAccount.DedicatedCoreQuotaPerVMFamily.Select(q => q.CoreQuota).Where(q => 0 != q) : Enumerable.Repeat(batchAccount.DedicatedCoreQuota ?? 0, 1); - var isBatchQuotaAvailable = batchAccount.LowPriorityCoreQuota > 0 || (batchAccount.DedicatedCoreQuota > 0 && maxPerFamilyQuota.Append(0).Max() > 0); - var isBatchPoolQuotaAvailable = batchAccount.PoolQuota > 0; - var isBatchJobQuotaAvailable = batchAccount.ActiveJobAndJobScheduleQuota > 0; + var batchAccountData = (await EnsureResourceDataAsync(batchAccount, r => r.HasData, r => r.GetAsync, cts.Token)).Data; + var maxPerFamilyQuota = batchAccountData.IsDedicatedCoreQuotaPerVmFamilyEnforced ?? false ? batchAccountData.DedicatedCoreQuotaPerVmFamily.Select(q => q.CoreQuota ?? 0).Where(q => 0 != q) : Enumerable.Repeat(batchAccountData.DedicatedCoreQuota ?? 0, 1); + var isBatchQuotaAvailable = batchAccountData.LowPriorityCoreQuota > 0 || (batchAccountData.DedicatedCoreQuota > 0 && maxPerFamilyQuota.Append(0).Max() > 0); + var isBatchPoolQuotaAvailable = batchAccountData.PoolQuota > 0; + var isBatchJobQuotaAvailable = batchAccountData.ActiveJobAndJobScheduleQuota > 0; var insufficientQuotas = new List(); int exitCode; @@ -636,7 +674,7 @@ await Execute( if (!isBatchPoolQuotaAvailable) insufficientQuotas.Add("pool"); if (!isBatchJobQuotaAvailable) insufficientQuotas.Add("job"); - if (insufficientQuotas.Any()) + if (0 != insufficientQuotas.Count) { if (!configuration.SkipTestWorkflow) { @@ -670,7 +708,7 @@ await Execute( var portForwardTask = startPortForward(tokenSource.Token); await Task.Delay(longRetryWaitTime * 2, tokenSource.Token); // Give enough time for kubectl to standup the port forwarding. - var runTestTask = RunTestTaskAsync("localhost:8088", batchAccount.LowPriorityCoreQuota > 0); + var runTestTask = RunTestTaskAsync("localhost:8088", isPreemptible: batchAccountData.LowPriorityCoreQuota > 0); for (var task = await Task.WhenAny(portForwardTask, runTestTask); runTestTask != task; @@ -785,7 +823,7 @@ await Execute( } } - private async Task PerformHelmDeploymentAsync(IResourceGroup resourceGroup, IEnumerable manualPrecommands = default, Func asyncTask = default) + private async Task PerformHelmDeploymentAsync(ContainerServiceManagedClusterResource cluster, IEnumerable manualPrecommands = default, Func asyncTask = default) { if (configuration.ManualHelmDeployment) { @@ -802,7 +840,7 @@ private async Task PerformHelmDeploymentAsync(IResourceGroup resourceGroup, IEnu } else { - var kubernetesClient = await kubernetesManager.GetKubernetesClientAsync(resourceGroup); + var kubernetesClient = await kubernetesManager.GetKubernetesClientAsync(cluster); await (asyncTask?.Invoke(kubernetesClient) ?? Task.CompletedTask); await kubernetesManager.DeployHelmChartToClusterAsync(kubernetesClient); } @@ -858,7 +896,7 @@ private async Task RunTestTaskAsync(string tesEndpoint, bool isPreemptible return isTestWorkflowSuccessful; } - private async Task ValidateAndGetExistingKeyVaultAsync() + private async Task ValidateAndGetExistingKeyVaultAsync() { if (string.IsNullOrWhiteSpace(configuration.KeyVaultName)) { @@ -869,7 +907,7 @@ private async Task ValidateAndGetExistingKeyVaultAsync() ?? throw new ValidationException($"If key vault name is provided, it must already exist in region {configuration.RegionName}, and be accessible to the current user.", displayExample: false); } - private async Task ValidateAndGetExistingPostgresqlServerAsync() + private async Task ValidateAndGetExistingPostgresqlServerAsync() { if (string.IsNullOrWhiteSpace(configuration.PostgreSqlServerName)) { @@ -891,15 +929,13 @@ private async Task ValidateAndGetExistin ?? throw new ValidationException($"If AKS cluster name is provided, the cluster must already exist in region {configuration.RegionName}, and be accessible to the current user.", displayExample: false); } - private async Task GetExistingPostgresqlServiceAsync(string serverName) + private async Task GetExistingPostgresqlServiceAsync(string serverName) { - var regex = new Regex(@"\s+"); - return await subscriptionIds.ToAsyncEnumerable().SelectAwait(async s => + return await subscriptionIds.ToAsyncEnumerable().Select(s => { try { - var client = new FlexibleServer.PostgreSQLManagementClient(tokenCredentials) { SubscriptionId = s }; - return (await client.Servers.ListAsync(cts.Token)).ToAsyncEnumerable(client.Servers.ListNextAsync); + return s.GetPostgreSqlFlexibleServersAsync(cts.Token); } catch (Exception e) { @@ -909,9 +945,10 @@ private async Task ValidateAndGetExistin }) .Where(a => a is not null) .SelectMany(a => a) + .SelectAwaitWithCancellation(async (a, ct) => await FetchResourceDataAsync(a.GetAsync, ct)) .SingleOrDefaultAsync(a => - a.Name.Equals(serverName, StringComparison.OrdinalIgnoreCase) && - regex.Replace(a.Location, string.Empty).Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), + a.Id.Name.Equals(serverName, StringComparison.OrdinalIgnoreCase) && + a.Data.Location.Name.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), cts.Token); } @@ -919,8 +956,7 @@ private async Task GetExistingAKSCluster { return await subscriptionIds.ToAsyncEnumerable() .SelectAwaitWithCancellation((sub, token) => ValueTask.FromResult>( - armClient.GetSubscriptionResource(SubscriptionResource.CreateResourceIdentifier(sub)) - .GetContainerServiceManagedClustersAsync(token))) + sub.GetContainerServiceManagedClustersAsync(token))) .Where(a => a is not null) .SelectMany(a => a) .SelectAwaitWithCancellation((resource, token) => @@ -945,10 +981,9 @@ static async ValueTask SafeSelectAsync(Func> selecto } } - private async Task ProvisionManagedClusterAsync(IResource resourceGroupObject, IIdentity managedIdentity, IGenericResource logAnalyticsWorkspace, INetwork virtualNetwork, string subnetName, bool privateNetworking) + private async Task ProvisionManagedClusterAsync(UserAssignedIdentityResource managedIdentity, OperationalInsightsWorkspaceResource logAnalyticsWorkspace, ResourceIdentifier subnetId, bool privateNetworking) { - var uami = (await armClient.GetUserAssignedIdentityResource(new(managedIdentity.Id)).GetAsync(cts.Token)).Value; - var resourceGroup = armClient.GetResourceGroupResource(new(resourceGroupObject.Id)); + var uami = await EnsureResourceDataAsync(managedIdentity, r => r.HasData, r => r.GetAsync, cts.Token); var nodePoolName = "nodepool1"; ContainerServiceManagedClusterData cluster = new(new(configuration.RegionName)) { @@ -962,6 +997,7 @@ private async Task ProvisionManagedClust NetworkPolicy = ContainerServiceNetworkPolicy.Azure } }; + ManagedClusterAddonProfile clusterAddonProfile = new(isEnabled: true); clusterAddonProfile.Config.Add("logAnalyticsWorkspaceResourceID", logAnalyticsWorkspace.Id); cluster.AddonProfiles.Add("omsagent", clusterAddonProfile); @@ -996,7 +1032,7 @@ private async Task ProvisionManagedClust OSType = ContainerServiceOSType.Linux, OSSku = ContainerServiceOSSku.AzureLinux, Mode = AgentPoolMode.System, - VnetSubnetId = new(virtualNetwork.Subnets[subnetName].Inner.Id), + VnetSubnetId = subnetId, }); if (privateNetworking) @@ -1015,23 +1051,19 @@ private async Task ProvisionManagedClust async () => (await resourceGroup.GetContainerServiceManagedClusters().CreateOrUpdateAsync(Azure.WaitUntil.Completed, configuration.AksClusterName, cluster, cts.Token)).Value); } - private async Task EnableWorkloadIdentity(ContainerServiceManagedClusterResource aksCluster, IIdentity managedIdentity, IResourceGroup resourceGroup) + private async Task EnableWorkloadIdentity(ContainerServiceManagedClusterResource aksCluster, UserAssignedIdentityResource managedIdentity, ResourceGroupResource resourceGroup) { - // Use the new ResourceManager sdk enable workload identity. - var armCluster = aksCluster.HasData ? aksCluster : (await aksCluster.GetAsync(cts.Token)).Value; - armCluster.Data.SecurityProfile.IsWorkloadIdentityEnabled = true; - armCluster.Data.OidcIssuerProfile.IsEnabled = true; - var coaRg = armClient.GetResourceGroupResource(new ResourceIdentifier(resourceGroup.Id)); - var aksClusterCollection = coaRg.GetContainerServiceManagedClusters(); - var cluster = await aksClusterCollection.CreateOrUpdateAsync(Azure.WaitUntil.Completed, armCluster.Data.Name, armCluster.Data, cts.Token); + aksCluster.Data.SecurityProfile.IsWorkloadIdentityEnabled = true; + aksCluster.Data.OidcIssuerProfile.IsEnabled = true; + var aksClusterCollection = resourceGroup.GetContainerServiceManagedClusters(); + var cluster = await aksClusterCollection.CreateOrUpdateAsync(Azure.WaitUntil.Completed, aksCluster.Data.Name, aksCluster.Data, cts.Token); var aksOidcIssuer = cluster.Value.Data.OidcIssuerProfile.IssuerUriInfo; - var uami = armClient.GetUserAssignedIdentityResource(new ResourceIdentifier(managedIdentity.Id)); - var federatedCredentialsCollection = uami.GetFederatedIdentityCredentials(); + var federatedCredentialsCollection = managedIdentity.GetFederatedIdentityCredentials(); var data = new FederatedIdentityCredentialData() { IssuerUri = new Uri(aksOidcIssuer), - Subject = $"system:serviceaccount:{configuration.AksCoANamespace}:{managedIdentity.Name}-sa" + Subject = $"system:serviceaccount:{configuration.AksCoANamespace}:{managedIdentity.Id.Name}-sa" }; data.Audiences.Add("api://AzureADTokenExchange"); @@ -1212,19 +1244,6 @@ private static void UpdateSetting(Dictionary settings, Dictio settings[key] = valueIsNullOrEmpty ? GetDefault() : ConvertValue(value); } - private static Microsoft.Azure.Management.Fluent.Azure.IAuthenticated GetAzureClient(AzureCredentials azureCredentials) - => Microsoft.Azure.Management.Fluent.Azure - .Configure() - .WithLogLevel(HttpLoggingDelegatingHandler.Level.Basic) - .Authenticate(azureCredentials); - - private IResourceManager GetResourceManagerClient(AzureCredentials azureCredentials) - => ResourceManager - .Configure() - .WithLogLevel(HttpLoggingDelegatingHandler.Level.Basic) - .Authenticate(azureCredentials) - .WithSubscription(configuration.SubscriptionId); - private async Task RegisterResourceProvidersAsync() { var unregisteredResourceProviders = await GetRequiredResourceProvidersNotRegisteredAsync(); @@ -1242,7 +1261,7 @@ await Execute( { await Task.WhenAll( unregisteredResourceProviders.Select(rp => - resourceManagerClient.Providers.RegisterAsync(rp, cts.Token)) + rp.RegisterAsync(cancellationToken: cts.Token)) ); // RP registration takes a few minutes; poll until done registering @@ -1256,7 +1275,7 @@ await Task.WhenAll( break; } - await Task.Delay(System.TimeSpan.FromSeconds(15), cts.Token); + await Task.Delay(TimeSpan.FromSeconds(15), cts.Token); } }); } @@ -1272,7 +1291,7 @@ await Task.WhenAll( ConsoleEx.WriteLine("2. Select Subscription -> Resource Providers", ConsoleColor.Yellow); ConsoleEx.WriteLine("3. Select each of the following and click Register:", ConsoleColor.Yellow); ConsoleEx.WriteLine(); - unregisteredResourceProviders.ForEach(rp => ConsoleEx.WriteLine($"- {rp}", ConsoleColor.Yellow)); + unregisteredResourceProviders.ForEach(rp => ConsoleEx.WriteLine($"- {rp.Data.Namespace}", ConsoleColor.Yellow)); ConsoleEx.WriteLine(); ConsoleEx.WriteLine("After completion, please re-attempt deployment."); @@ -1280,14 +1299,14 @@ await Task.WhenAll( } } - private async Task> GetRequiredResourceProvidersNotRegisteredAsync() + private async ValueTask> GetRequiredResourceProvidersNotRegisteredAsync() { - var cloudResourceProviders = (await resourceManagerClient.Providers.ListAsync(cancellationToken: cts.Token)).ToAsyncEnumerable(); + var cloudResourceProviders = armSubscription.GetResourceProviders().GetAllAsync(cancellationToken: cts.Token); - var notRegisteredResourceProviders = await requiredResourceProviders.ToAsyncEnumerable() - .Intersect(cloudResourceProviders - .Where(rp => !rp.RegistrationState.Equals("Registered", StringComparison.OrdinalIgnoreCase)) - .Select(rp => rp.Namespace), StringComparer.OrdinalIgnoreCase) + var notRegisteredResourceProviders = await cloudResourceProviders + .SelectAwaitWithCancellation(async (rp, ct) => await FetchResourceDataAsync(token => rp.GetAsync(cancellationToken: token), ct)) + .Where(rp => requiredResourceProviders.Contains(rp.Data.Namespace, StringComparer.OrdinalIgnoreCase)) + .Where(rp => !rp.Data.RegistrationState.Equals("Registered", StringComparison.OrdinalIgnoreCase)) .ToListAsync(cts.Token); return notRegisteredResourceProviders; @@ -1300,48 +1319,46 @@ private async Task RegisterResourceProviderFeaturesAsync() { await Execute( $"Registering resource provider features...", - async () => - { - var subscription = armClient.GetSubscriptionResource(new($"/subscriptions/{configuration.SubscriptionId}")); - - foreach (var rpName in requiredResourceProviderFeatures.Keys) + async () => { - var rp = await subscription.GetResourceProviderAsync(rpName, cancellationToken: cts.Token); - - foreach (var featureName in requiredResourceProviderFeatures[rpName]) + foreach (var rpName in requiredResourceProviderFeatures.Keys) { - var feature = await rp.Value.GetFeatureAsync(featureName, cts.Token); + var rp = await armSubscription.GetResourceProviderAsync(rpName, cancellationToken: cts.Token); - if (!string.Equals(feature.Value.Data.FeatureState, "Registered", StringComparison.OrdinalIgnoreCase)) + foreach (var featureName in requiredResourceProviderFeatures[rpName]) { - unregisteredFeatures.Add(feature); - _ = await feature.Value.RegisterAsync(cts.Token); + var feature = await rp.Value.GetFeatureAsync(featureName, cts.Token); + + if (!string.Equals(feature.Value.Data.FeatureState, "Registered", StringComparison.OrdinalIgnoreCase)) + { + unregisteredFeatures.Add(feature); + _ = await feature.Value.RegisterAsync(cts.Token); + } } } - } - while (!cts.IsCancellationRequested) - { - if (unregisteredFeatures.Count == 0) + while (!cts.IsCancellationRequested) { - break; - } - - await Task.Delay(System.TimeSpan.FromSeconds(30), cts.Token); - var finished = new List(); + if (unregisteredFeatures.Count == 0) + { + break; + } - foreach (var feature in unregisteredFeatures) - { - var update = await feature.GetAsync(cts.Token); + await Task.Delay(TimeSpan.FromSeconds(30), cts.Token); + var finished = new List(); - if (string.Equals(update.Value.Data.FeatureState, "Registered", StringComparison.OrdinalIgnoreCase)) + foreach (var feature in unregisteredFeatures) { - finished.Add(feature); + var update = await feature.GetAsync(cts.Token); + + if (string.Equals(update.Value.Data.FeatureState, "Registered", StringComparison.OrdinalIgnoreCase)) + { + finished.Add(feature); + } } + unregisteredFeatures.RemoveAll(x => finished.Contains(x)); } - unregisteredFeatures.RemoveAll(x => finished.Contains(x)); - } - }); + }); } catch (Microsoft.Rest.Azure.CloudException ex) when (ex.ToCloudErrorType() == CloudErrorType.AuthorizationFailed) { @@ -1361,7 +1378,7 @@ await Execute( } } - private async Task TryAssignMIAsNetworkContributorToResourceAsync(IIdentity managedIdentity, IResource resource) + private async Task TryAssignMIAsNetworkContributorToResourceAsync(UserAssignedIdentityResource managedIdentity, ArmResource resource) { try { @@ -1376,40 +1393,37 @@ private async Task TryAssignMIAsNetworkContributorToResourceAsync(IIdentit } } - private Task AssignMIAsNetworkContributorToResourceAsync(IIdentity managedIdentity, IResource resource, bool cancelOnException = true) + private Task AssignMIAsNetworkContributorToResourceAsync(UserAssignedIdentityResource managedIdentity, ArmResource resource, bool cancelOnException = true) { // https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#network-contributor - var roleDefinitionId = $"/subscriptions/{configuration.SubscriptionId}/providers/Microsoft.Authorization/roleDefinitions/4d97b98b-1d4f-4787-a291-c67834d212e7"; + var roleDefinitionId = AuthorizationRoleDefinitionResource.CreateResourceIdentifier(SubscriptionResource.CreateResourceIdentifier(configuration.SubscriptionId), new("4d97b98b-1d4f-4787-a291-c67834d212e7")); return Execute( - $"Assigning Network Contributor role for the managed id to resource group scope...", + "Assigning 'Network Contributor' role for the managed id to resource group scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithRoleDefinition(roleDefinitionId) - .WithResourceScope(resource) - .CreateAsync(ct), + ct => (Task)resource.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(roleDefinitionId, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token), cancelOnException: cancelOnException); } - private Task AssignManagedIdOperatorToResourceAsync(IIdentity managedIdentity, IResource resource) + private Task AssignManagedIdOperatorToResourceAsync(UserAssignedIdentityResource managedIdentity, ArmResource resource) { // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#managed-identity-operator - var roleDefinitionId = $"/subscriptions/{configuration.SubscriptionId}/providers/Microsoft.Authorization/roleDefinitions/f1a07417-d97a-45cb-824c-7a7467783830"; + var roleDefinitionId = AuthorizationRoleDefinitionResource.CreateResourceIdentifier(SubscriptionResource.CreateResourceIdentifier(configuration.SubscriptionId), new("f1a07417-d97a-45cb-824c-7a7467783830")); return Execute( - $"Assigning Managed ID Operator role for the managed id to resource group scope...", + "Assigning 'Managed ID Operator' role for the managed id to resource group scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithRoleDefinition(roleDefinitionId) - .WithResourceScope(resource) - .CreateAsync(ct), - cts.Token)); + ct => (Task)resource.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(roleDefinitionId, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token)); } - private async Task TryAssignVmAsDataOwnerToStorageAccountAsync(IIdentity managedIdentity, IStorageAccount storageAccount) + private async Task TryAssignVmAsDataOwnerToStorageAccountAsync(UserAssignedIdentityResource managedIdentity, StorageAccountResource storageAccount) { try { @@ -1424,54 +1438,51 @@ private async Task TryAssignVmAsDataOwnerToStorageAccountAsync(IIdentity m } } - private Task AssignVmAsDataOwnerToStorageAccountAsync(IIdentity managedIdentity, IStorageAccount storageAccount, bool cancelOnException = true) + private Task AssignVmAsDataOwnerToStorageAccountAsync(UserAssignedIdentityResource managedIdentity, StorageAccountResource storageAccount, bool cancelOnException = true) { //https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#storage-blob-data-owner - var roleDefinitionId = $"/subscriptions/{configuration.SubscriptionId}/providers/Microsoft.Authorization/roleDefinitions/b7e6dc6d-f1e8-4753-8033-0f276bb0955b"; + ResourceIdentifier roleDefinitionId = new($"/subscriptions/{configuration.SubscriptionId}/providers/Microsoft.Authorization/roleDefinitions/b7e6dc6d-f1e8-4753-8033-0f276bb0955b"); return Execute( - $"Assigning Storage Blob Data Owner role for user-managed identity to Storage Account resource scope...", + "Assigning 'Storage Blob Data Owner' role for user-managed identity to Storage Account resource scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithRoleDefinition(roleDefinitionId) - .WithResourceScope(storageAccount) - .CreateAsync(ct), + ct => (Task)storageAccount.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(roleDefinitionId, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token), cancelOnException: cancelOnException); } - private Task AssignVmAsContributorToStorageAccountAsync(IIdentity managedIdentity, IResource storageAccount) + private Task AssignVmAsContributorToStorageAccountAsync(UserAssignedIdentityResource managedIdentity, StorageAccountResource storageAccount) => Execute( - $"Assigning {BuiltInRole.Contributor} role for user-managed identity to Storage Account resource scope...", + "Assigning 'Contributor' role for user-managed identity to Storage Account resource scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithBuiltInRole(BuiltInRole.Contributor) - .WithResourceScope(storageAccount) - .CreateAsync(ct), - cts.Token)); - - private Task CreateStorageAccountAsync() + ct => (Task)storageAccount.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(All_Role_Contributor, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token)); + + private Task CreateStorageAccountAsync() => Execute( $"Creating Storage Account: {configuration.StorageAccountName}...", - () => azureSubscriptionClient.StorageAccounts - .Define(configuration.StorageAccountName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(configuration.ResourceGroupName) - .WithGeneralPurposeAccountKindV2() - .WithOnlyHttpsTraffic() - .WithSku(StorageAccountSkuType.Standard_LRS) - .CreateAsync(cts.Token)); - - private async Task GetExistingStorageAccountAsync(string storageAccountName) - => await subscriptionIds.ToAsyncEnumerable().SelectAwait(async s => + async () => (await resourceGroup.GetStorageAccounts().CreateOrUpdateAsync(WaitUntil.Completed, + configuration.StorageAccountName, + new( + new(Storage.StorageSkuName.StandardLrs), + Storage.StorageKind.StorageV2, + new(configuration.RegionName)) + { EnableHttpsTrafficOnly = true }, + cts.Token)).Value); + + private async Task GetExistingStorageAccountAsync(string storageAccountName) + => await subscriptionIds.ToAsyncEnumerable().Select(s => { try { - return (await azureClient.WithSubscription(s).StorageAccounts.ListAsync(cancellationToken: cts.Token)).ToAsyncEnumerable(); + return s.GetStorageAccountsAsync(cts.Token); } catch (Exception) { @@ -1481,19 +1492,18 @@ private async Task GetExistingStorageAccountAsync(string storag }) .Where(a => a is not null) .SelectMany(a => a) + .SelectAwaitWithCancellation(async (a, ct) => await FetchResourceDataAsync(token => a.GetAsync(cancellationToken: token), ct)) .SingleOrDefaultAsync(a => - a.Name.Equals(storageAccountName, StringComparison.OrdinalIgnoreCase) && - a.RegionName.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), + a.Id.Name.Equals(storageAccountName, StringComparison.OrdinalIgnoreCase) && + a.Data.Location.Name.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), cts.Token); - private async Task GetExistingBatchAccountAsync(string batchAccountName) - => await subscriptionIds.ToAsyncEnumerable().SelectAwait(async s => + private async Task GetExistingBatchAccountAsync(string batchAccountName) + => await subscriptionIds.ToAsyncEnumerable().Select(s => { try { - var client = new BatchManagementClient(tokenCredentials) { SubscriptionId = s }; - return (await client.BatchAccount.ListAsync(cts.Token)) - .ToAsyncEnumerable(client.BatchAccount.ListNextAsync); + return s.GetBatchAccountsAsync(cts.Token); } catch (Exception e) { @@ -1503,72 +1513,73 @@ private async Task GetExistingBatchAccountAsync(string batchAccoun }) .Where(a => a is not null) .SelectMany(a => a) + .SelectAwaitWithCancellation(async (a, ct) => await FetchResourceDataAsync(a.GetAsync, ct)) .SingleOrDefaultAsync(a => - a.Name.Equals(batchAccountName, StringComparison.OrdinalIgnoreCase) && - a.Location.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), + a.Id.Name.Equals(batchAccountName, StringComparison.OrdinalIgnoreCase) && + a.Data.Location.Value.Name.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase), cts.Token); - private async Task CreateDefaultStorageContainersAsync(IStorageAccount storageAccount) + private async Task CreateDefaultStorageContainersAsync(StorageAccountResource storageAccount) { - var blobClient = await GetBlobClientAsync(storageAccount, cts.Token); + List defaultContainers = [TesInternalContainerName, InputsContainerName, "outputs", ConfigurationContainerName]; - var defaultContainers = new List { TesInternalContainerName, InputsContainerName, "outputs", ConfigurationContainerName }; - await Task.WhenAll(defaultContainers.Select(c => blobClient.GetBlobContainerClient(c).CreateIfNotExistsAsync(cancellationToken: cts.Token))); + var containerCollection = storageAccount.GetBlobService().GetBlobContainers(); + await Task.WhenAll(await defaultContainers.ToAsyncEnumerable() + .Select(name => containerCollection.CreateOrUpdateAsync(WaitUntil.Completed, name, new(), cts.Token)) + .ToArrayAsync(cts.Token)); } - private Task WritePersonalizedFilesToStorageAccountAsync(IStorageAccount storageAccount) + private Task WritePersonalizedFilesToStorageAccountAsync(StorageAccountData storageAccount) => Execute( $"Writing {AllowedVmSizesFileName} file to '{TesInternalContainerName}' storage container...", async () => { - await UploadTextToStorageAccountAsync(storageAccount, TesInternalContainerName, $"{ConfigurationContainerName}/{AllowedVmSizesFileName}", Utility.GetFileContent("scripts", AllowedVmSizesFileName), cts.Token); + await UploadTextToStorageAccountAsync(GetBlobClient(storageAccount, TesInternalContainerName, $"{ConfigurationContainerName}/{AllowedVmSizesFileName}"), Utility.GetFileContent("scripts", AllowedVmSizesFileName), cts.Token); }); - private Task AssignVmAsContributorToBatchAccountAsync(IIdentity managedIdentity, BatchAccount batchAccount) + private Task AssignVmAsContributorToBatchAccountAsync(UserAssignedIdentityResource managedIdentity, BatchAccountResource batchAccount) => Execute( - $"Assigning {BuiltInRole.Contributor} role for user-managed identity to Batch Account resource scope...", + "Assigning 'Contributor' role for user-managed identity to Batch Account resource scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithBuiltInRole(BuiltInRole.Contributor) - .WithScope(batchAccount.Id) - .CreateAsync(ct), - cts.Token)); - - private async Task CreatePostgreSqlServerAndDatabaseAsync(FlexibleServer.IPostgreSQLManagementClient postgresManagementClient, ISubnet subnet, IPrivateDnsZone postgreSqlDnsZone) + ct => (Task)batchAccount.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(All_Role_Contributor, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token)); + + private async Task CreatePostgreSqlServerAndDatabaseAsync(SubnetResource subnet, PrivateDnsZoneResource postgreSqlDnsZone) { - if (!subnet.Inner.Delegations.Any()) + subnet = await EnsureResourceDataAsync(subnet, r => r.HasData, r => ct => r.GetAsync(cancellationToken: ct), cts.Token); + + if (!subnet.Data.Delegations.Any()) { - subnet.Parent.Update().UpdateSubnet(subnet.Name).WithDelegation("Microsoft.DBforPostgreSQL/flexibleServers"); - await subnet.Parent.Update().ApplyAsync(); + subnet.Data.Delegations.Add(new() { ServiceName = "Microsoft.DBforPostgreSQL/flexibleServers" }); + await subnet.UpdateAsync(WaitUntil.Completed, subnet.Data, cts.Token); } - FlexibleServerModel.Server server = null; + PostgreSqlFlexibleServerData data = new(new(configuration.RegionName)) + { + Version = new(configuration.PostgreSqlVersion), + Sku = new(configuration.PostgreSqlSkuName, configuration.PostgreSqlTier), + StorageSizeInGB = configuration.PostgreSqlStorageSize, + AdministratorLogin = configuration.PostgreSqlAdministratorLogin, + AdministratorLoginPassword = configuration.PostgreSqlAdministratorPassword, + Network = new() + { + /*PublicNetworkAccess = PostgreSqlFlexibleServerPublicNetworkAccessState.Disabled,*/ + DelegatedSubnetResourceId = subnet.Id, + PrivateDnsZoneArmResourceId = postgreSqlDnsZone.Id + }, + HighAvailability = new() { Mode = PostgreSqlFlexibleServerHighAvailabilityMode.Disabled }, + }; - await Execute( + var server = await Execute( $"Creating Azure Flexible Server for PostgreSQL: {configuration.PostgreSqlServerName}...", - async () => - { - server = await postgresManagementClient.Servers.CreateAsync( - configuration.ResourceGroupName, configuration.PostgreSqlServerName, - new( - location: configuration.RegionName, - version: configuration.PostgreSqlVersion, - sku: new(configuration.PostgreSqlSkuName, configuration.PostgreSqlTier), - storage: new(configuration.PostgreSqlStorageSize), - administratorLogin: configuration.PostgreSqlAdministratorLogin, - administratorLoginPassword: configuration.PostgreSqlAdministratorPassword, - network: new(publicNetworkAccess: "Disabled", delegatedSubnetResourceId: subnet.Inner.Id, privateDnsZoneArmResourceId: postgreSqlDnsZone.Id), - highAvailability: new("Disabled") - )); - }); + async () => (await resourceGroup.GetPostgreSqlFlexibleServers().CreateOrUpdateAsync(WaitUntil.Completed, configuration.PostgreSqlServerName, data, cts.Token)).Value); await Execute( $"Creating PostgreSQL tes database: {configuration.PostgreSqlTesDatabaseName}...", - () => postgresManagementClient.Databases.CreateAsync( - configuration.ResourceGroupName, configuration.PostgreSqlServerName, configuration.PostgreSqlTesDatabaseName, - new())); + () => server.GetPostgreSqlFlexibleServerDatabases().CreateOrUpdateAsync(WaitUntil.Completed, configuration.PostgreSqlTesDatabaseName, new(), cts.Token)); return server; } @@ -1580,7 +1591,7 @@ private string GetCreateTesUserString() private Task ExecuteQueriesOnAzurePostgreSQLDbFromK8(IKubernetes kubernetesClient, string podName, string aksNamespace) => Execute( - $"Executing scripts on postgresql...", + "Executing scripts on postgresql...", async () => { var tesScript = GetCreateTesUserString(); @@ -1601,190 +1612,169 @@ private Task ExecuteQueriesOnAzurePostgreSQLDbFromK8(IKubernetes kubernetesClien await kubernetesManager.ExecuteCommandsOnPodAsync(kubernetesClient, podName, commands, aksNamespace); }); - private Task AssignVmAsContributorToAppInsightsAsync(IIdentity managedIdentity, IResource appInsights) + private Task AssignVmAsContributorToAppInsightsAsync(UserAssignedIdentityResource managedIdentity, ArmResource appInsights) => Execute( - $"Assigning {BuiltInRole.Contributor} role for user-managed identity to App Insights resource scope...", + "Assigning 'Contributor' role for user-managed identity to App Insights resource scope...", () => roleAssignmentHashConflictRetryPolicy.ExecuteAsync( - ct => azureSubscriptionClient.AccessManagement.RoleAssignments - .Define(Guid.NewGuid().ToString()) - .ForObjectId(managedIdentity.PrincipalId) - .WithBuiltInRole(BuiltInRole.Contributor) - .WithResourceScope(appInsights) - .CreateAsync(ct), - cts.Token)); - - private Task<(INetwork virtualNetwork, ISubnet vmSubnet, ISubnet postgreSqlSubnet, ISubnet batchSubnet)> CreateVnetAndSubnetsAsync(IResourceGroup resourceGroup) + ct => (Task)appInsights.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), + new(All_Role_Contributor, managedIdentity.Data.PrincipalId.Value) + { + PrincipalType = Azure.ResourceManager.Authorization.Models.RoleManagementPrincipalType.ServicePrincipal + }, ct), cts.Token)); + + private Task<(VirtualNetworkResource virtualNetwork, SubnetResource vmSubnet, SubnetResource postgreSqlSubnet, SubnetResource batchSubnet)> CreateVnetAndSubnetsAsync() => Execute( $"Creating virtual network and subnets: {configuration.VnetName}...", async () => { - var tesPorts = new List { }; + List tesPorts = []; if (configuration.EnableIngress.GetValueOrDefault()) { tesPorts = [80, 443]; } - var defaultNsg = await CreateNetworkSecurityGroupAsync(resourceGroup, $"{configuration.VnetName}-default-nsg"); - var aksNsg = await CreateNetworkSecurityGroupAsync(resourceGroup, $"{configuration.VnetName}-aks-nsg", tesPorts); + var defaultNsg = (await EnsureResourceDataAsync(await CreateNetworkSecurityGroupAsync($"{configuration.VnetName}-default-nsg"), nsg => nsg.HasData, nsg => ct => nsg.GetAsync(cancellationToken: ct), cts.Token)).Data; + var aksNsg = (await EnsureResourceDataAsync(await CreateNetworkSecurityGroupAsync($"{configuration.VnetName}-aks-nsg", tesPorts), nsg => nsg.HasData, nsg => ct => nsg.GetAsync(cancellationToken: ct), cts.Token)).Data; - var vnetDefinition = azureSubscriptionClient.Networks - .Define(configuration.VnetName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(resourceGroup) - .WithAddressSpace(configuration.VnetAddressSpace) - .DefineSubnet(configuration.VmSubnetName) - .WithAddressPrefix(configuration.VmSubnetAddressSpace) - .WithExistingNetworkSecurityGroup(aksNsg) - .Attach(); + VirtualNetworkData vnetDefinition = new() { Location = new(configuration.RegionName) }; + vnetDefinition.AddressPrefixes.Add(configuration.VnetAddressSpace); - vnetDefinition = vnetDefinition.DefineSubnet(configuration.PostgreSqlSubnetName) - .WithAddressPrefix(configuration.PostgreSqlSubnetAddressSpace) - .WithExistingNetworkSecurityGroup(defaultNsg) - .WithDelegation("Microsoft.DBforPostgreSQL/flexibleServers") - .Attach(); - - vnetDefinition = vnetDefinition.DefineSubnet(configuration.BatchSubnetName) - .WithAddressPrefix(configuration.BatchNodesSubnetAddressSpace) - .WithExistingNetworkSecurityGroup(defaultNsg) - .Attach(); - - var vnet = await vnetDefinition.CreateAsync(cts.Token); - var batchSubnet = vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.BatchSubnetName, StringComparison.OrdinalIgnoreCase)).Value; + vnetDefinition.Subnets.Add(new() + { + Name = configuration.VmSubnetName, + AddressPrefix = configuration.VmSubnetAddressSpace, + NetworkSecurityGroup = aksNsg, + }); - // Use the new ResourceManager sdk to add the ACR service endpoint since it is absent from the fluent sdk. - var armBatchSubnet = (await armClient.GetSubnetResource(new ResourceIdentifier(batchSubnet.Inner.Id)).GetAsync(cancellationToken: cts.Token)).Value; + SubnetData postgreSqlSubnet = new() + { + Name = configuration.PostgreSqlSubnetName, + AddressPrefix = configuration.PostgreSqlSubnetAddressSpace, + NetworkSecurityGroup = defaultNsg, + }; + postgreSqlSubnet.Delegations.Add(NewServiceDelegation("Microsoft.DBforPostgreSQL/flexibleServers")); + vnetDefinition.Subnets.Add(postgreSqlSubnet); - AddServiceEndpointsToSubnet(armBatchSubnet.Data); + SubnetData batchSubnet = new() + { + Name = configuration.BatchSubnetName, + AddressPrefix = configuration.BatchNodesSubnetAddressSpace, + NetworkSecurityGroup = defaultNsg, + }; + AddServiceEndpointsToSubnet(batchSubnet); + vnetDefinition.Subnets.Add(batchSubnet); - await armBatchSubnet.UpdateAsync(Azure.WaitUntil.Completed, armBatchSubnet.Data, cts.Token); + var vnet = (await resourceGroup.GetVirtualNetworks().CreateOrUpdateAsync(WaitUntil.Completed, configuration.VnetName, vnetDefinition, cts.Token)).Value; + var subnets = await vnet.GetSubnets().ToListAsync(cts.Token); return (vnet, - vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.VmSubnetName, StringComparison.OrdinalIgnoreCase)).Value, - vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.PostgreSqlSubnetName, StringComparison.OrdinalIgnoreCase)).Value, - batchSubnet); + subnets.FirstOrDefault(s => s.Id.Name.Equals(configuration.VmSubnetName, StringComparison.OrdinalIgnoreCase)), + subnets.FirstOrDefault(s => s.Id.Name.Equals(configuration.PostgreSqlSubnetName, StringComparison.OrdinalIgnoreCase)), + subnets.FirstOrDefault(s => s.Id.Name.Equals(configuration.BatchSubnetName, StringComparison.OrdinalIgnoreCase))); + + static ServiceDelegation NewServiceDelegation(string serviceDelegation) => + new() { Name = serviceDelegation, ServiceName = serviceDelegation }; }); - private Task CreateNetworkSecurityGroupAsync(IResourceGroup resourceGroup, string networkSecurityGroupName, List openPorts = null) + private async Task CreateNetworkSecurityGroupAsync(string networkSecurityGroupName, IEnumerable openPorts = null) { - var icreate = azureSubscriptionClient.NetworkSecurityGroups.Define(networkSecurityGroupName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(resourceGroup); + NetworkSecurityGroupData data = new() { Location = new(configuration.RegionName) }; if (openPorts is not null) { - var i = 0; - foreach (var port in openPorts) + foreach (var (port, i) in openPorts.Select((p, i) => (p, i))) { - icreate = icreate - .DefineRule($"ALLOW-{port}") - .AllowInbound() - .FromAnyAddress() - .FromAnyPort() - .ToAnyAddress() - .ToPort(port) - .WithAnyProtocol() - .WithPriority(1000 + i) - .Attach(); - i++; + data.SecurityRules.Add(new() + { + Name = $"ALLOW-{port}", + Access = SecurityRuleAccess.Allow, + Direction = SecurityRuleDirection.Inbound, + SourceAddressPrefix = "*", + SourcePortRange = "*", + DestinationAddressPrefix = "*", + DestinationPortRange = port.ToString(System.Globalization.CultureInfo.InvariantCulture), + Protocol = SecurityRuleProtocol.Asterisk, + Priority = 1000 + i, + }); } } - return icreate.CreateAsync(cts.Token); + return (await resourceGroup.GetNetworkSecurityGroups().CreateOrUpdateAsync(WaitUntil.Completed, networkSecurityGroupName, data, cts.Token)).Value; } - private Task CreatePrivateDnsZoneAsync(INetwork virtualNetwork, string name, string title) + private Task CreatePrivateDnsZoneAsync(VirtualNetworkResource virtualNetwork, string name, string title) => Execute( $"Creating private DNS Zone for {title}...", async () => { - // Note: for a potential future implementation of this method without Fluent, - // please see commit cbffa28 in #392 - var dnsZone = await azureSubscriptionClient.PrivateDnsZones - .Define(name) - .WithExistingResourceGroup(configuration.ResourceGroupName) - .DefineVirtualNetworkLink($"{virtualNetwork.Name}-link") - .WithReferencedVirtualNetworkId(virtualNetwork.Id) - .DisableAutoRegistration() - .Attach() - .CreateAsync(cts.Token); + var dnsZone = (await resourceGroup.GetPrivateDnsZones() + .CreateOrUpdateAsync(WaitUntil.Completed, name, new(new("global")), cancellationToken: cts.Token)).Value; + VirtualNetworkLinkData data = new(new("global")) + { + VirtualNetworkId = virtualNetwork.Id, + RegistrationEnabled = false + }; + _ = await dnsZone.GetVirtualNetworkLinks().CreateOrUpdateAsync(WaitUntil.Completed, $"{virtualNetwork.Id.Name}-link", data, cancellationToken: cts.Token); return dnsZone; }); - private async Task SetStorageKeySecret(string vaultUrl, string secretName, string secretValue) + private async Task SetStorageKeySecret(Uri vaultUrl, string secretName, string secretValue) { - var client = new SecretClient(new(vaultUrl), new DefaultAzureCredential(new DefaultAzureCredentialOptions { AuthorityHost = new Uri(azureCloudConfig.Authentication.LoginEndpointUrl) })); + var client = new SecretClient(vaultUrl, tokenCredential); await client.SetSecretAsync(secretName, secretValue, cts.Token); } - private Task GetKeyVaultAsync(string vaultName) + private async Task GetKeyVaultAsync(string vaultName) { - var keyVaultManagementClient = new KeyVaultManagementClient(azureCredentials) { SubscriptionId = configuration.SubscriptionId }; - return keyVaultManagementClient.Vaults.GetAsync(configuration.ResourceGroupName, vaultName, cts.Token); + return resourceGroup is null + ? (await armSubscription.GetKeyVaultsAsync(cancellationToken: cts.Token).FirstOrDefaultAsync(r => r.Id.ResourceGroupName.Equals(configuration.ResourceGroupName, StringComparison.OrdinalIgnoreCase), cts.Token)) + : (await resourceGroup.GetKeyVaultAsync(vaultName, cts.Token)).Value; } - private Task CreateKeyVaultAsync(string vaultName, IIdentity managedIdentity, ISubnet subnet) + private Task CreateKeyVaultAsync(string vaultName, UserAssignedIdentityResource managedIdentity, VirtualNetworkResource virtualNetwork, SubnetResource subnet) => Execute( $"Creating Key Vault: {vaultName}...", async () => { - var tenantId = managedIdentity.TenantId; - var secrets = new List + if (!managedIdentity.HasData) { - "get", - "list", - "set", - "delete", - "backup", - "restore", - "recover", - "purge" - }; + throw new ArgumentException("Resource data has not been fetched.", nameof(managedIdentity)); + } - var keyVaultManagementClient = new KeyVaultManagementClient(azureCredentials) { SubscriptionId = configuration.SubscriptionId }; - var properties = new VaultCreateOrUpdateParameters() + var tenantId = managedIdentity.Data.TenantId; + IdentityAccessPermissions permissions = new(); + permissions.Secrets.Add(IdentityAccessSecretPermission.Get); + permissions.Secrets.Add(IdentityAccessSecretPermission.List); + permissions.Secrets.Add(IdentityAccessSecretPermission.Set); + permissions.Secrets.Add(IdentityAccessSecretPermission.Delete); + permissions.Secrets.Add(IdentityAccessSecretPermission.Backup); + permissions.Secrets.Add(IdentityAccessSecretPermission.Restore); + permissions.Secrets.Add(IdentityAccessSecretPermission.Recover); + permissions.Secrets.Add(IdentityAccessSecretPermission.Purge); + + KeyVaultProperties properties = new(tenantId.Value, new(KeyVaultSkuFamily.A, KeyVaultSkuName.Standard)) { - Location = configuration.RegionName, - Properties = new() + NetworkRuleSet = new() { - TenantId = new(tenantId), - Sku = new(SkuName.Standard), - NetworkAcls = new() - { - DefaultAction = configuration.PrivateNetworking.GetValueOrDefault() ? "Deny" : "Allow" - }, - AccessPolicies = - [ - new() - { - TenantId = new(tenantId), - ObjectId = await GetUserObjectId(), - Permissions = new() - { - Secrets = secrets - } - }, - new() - { - TenantId = new(tenantId), - ObjectId = managedIdentity.PrincipalId, - Permissions = new() - { - Secrets = secrets - } - } - ] - } + DefaultAction = configuration.PrivateNetworking.GetValueOrDefault() ? KeyVaultNetworkRuleAction.Deny : KeyVaultNetworkRuleAction.Allow + }, }; - var vault = await keyVaultManagementClient.Vaults.CreateOrUpdateAsync(configuration.ResourceGroupName, vaultName, properties, cts.Token); + properties.AccessPolicies.AddRange( + [ + new(tenantId.Value, await GetUserObjectId(), permissions), + new(tenantId.Value, managedIdentity.Data.PrincipalId.Value.ToString("D"), permissions), + ]); + + var vault = (await resourceGroup.GetKeyVaults().CreateOrUpdateAsync(WaitUntil.Completed, vaultName, new(new(configuration.RegionName), properties), cts.Token)).Value; if (configuration.PrivateNetworking.GetValueOrDefault()) { var connection = new NetworkPrivateLinkServiceConnection { Name = "pe-coa-keyvault", - PrivateLinkServiceId = new(vault.Id) + PrivateLinkServiceId = vault.Id }; connection.GroupIds.Add("vault"); @@ -1792,125 +1782,146 @@ private Task CreateKeyVaultAsync(string vaultName, IIdentity managedIdent { CustomNetworkInterfaceName = "pe-coa-keyvault", ExtendedLocation = new() { Name = configuration.RegionName }, - Subnet = new() { Id = new(subnet.Inner.Id), Name = subnet.Name } + Subnet = new() { Id = subnet.Id, Name = subnet.Id.Name } }; endpointData.PrivateLinkServiceConnections.Add(connection); - var privateEndpoint = (await armClient - .GetResourceGroupResource(new ResourceIdentifier(subnet.Parent.Inner.Id).Parent) + var privateEndpoint = (await resourceGroup .GetPrivateEndpoints() - .CreateOrUpdateAsync(Azure.WaitUntil.Completed, "pe-keyvault", endpointData, cts.Token)) + .CreateOrUpdateAsync(WaitUntil.Completed, "pe-keyvault", endpointData, cts.Token)) .Value.Data; var networkInterface = privateEndpoint.NetworkInterfaces[0]; - var dnsZone = await CreatePrivateDnsZoneAsync(subnet.Parent, $"privatelink.{azureCloudConfig.Suffixes.KeyVaultDnsSuffix}", "KeyVault"); - await dnsZone - .Update() - .DefineARecordSet(vault.Name) - .WithIPv4Address(networkInterface.IPConfigurations.First().PrivateIPAddress) - .Attach() - .ApplyAsync(cts.Token); + var dnsZone = await CreatePrivateDnsZoneAsync(virtualNetwork, "privatelink.vaultcore.azure.net", "KeyVault"); + PrivateDnsARecordData aRecordData = new(); + aRecordData.PrivateDnsARecords.Add(new() + { + IPv4Address = IPAddress.Parse(networkInterface.IPConfigurations.First(c => NetworkIPVersion.IPv4.Equals(c.PrivateIPAddressVersion)).PrivateIPAddress) + }); + _ = await dnsZone + .GetPrivateDnsARecords() + .CreateOrUpdateAsync(WaitUntil.Completed, vault.Id.Name, aRecordData, cancellationToken: cts.Token); } return vault; async ValueTask GetUserObjectId() { - const string graphUri = "https://graph.windows.net//.default"; - var credentials = new AzureCredentials(default, new TokenCredentials(new RefreshableAzureServiceTokenProvider(graphUri)), tenantId, AzureEnvironment.AzureGlobalCloud); - using GraphRbacManagementClient rbacClient = new(Configure().WithEnvironment(AzureEnvironment.AzureGlobalCloud).WithCredentials(credentials).WithBaseUri(graphUri).Build()) { TenantID = tenantId }; - credentials.InitializeServiceClient(rbacClient); - return (await rbacClient.SignedInUser.GetAsync(cts.Token)).ObjectId; + string baseUrl; + { + using var client = GraphClientFactory.Create(nationalCloud: NationalCloud()); + baseUrl = client.BaseAddress.AbsoluteUri; + } + { + using var client = new GraphServiceClient(tokenCredential, baseUrl: baseUrl); + return (await client.Me.GetAsync(cancellationToken: cts.Token)).Id; + } + } + + // Note that there are two different values for USGovernment. + string NationalCloud() + { + if (cloudEnvironment.ArmEnvironment.Endpoint == ArmEnvironment.AzurePublicCloud.Endpoint) + { + return GraphClientFactory.Global_Cloud; + } + + if (cloudEnvironment.ArmEnvironment.Endpoint == ArmEnvironment.AzureChina.Endpoint) + { + return GraphClientFactory.China_Cloud; + } + + if (cloudEnvironment.ArmEnvironment.Endpoint == ArmEnvironment.AzureGovernment.Endpoint) + { + return GraphClientFactory.USGOV_Cloud; // TODO: when should we return GraphClientFactory.USGOV_DOD_Cloud? + } + + return GraphClientFactory.Global_Cloud; } }); - private Task CreateLogAnalyticsWorkspaceResourceAsync(string workspaceName) + private Task CreateLogAnalyticsWorkspaceResourceAsync(string workspaceName) => Execute( $"Creating Log Analytics Workspace: {workspaceName}...", - () => ResourceManager - .Configure() - .Authenticate(azureCredentials) - .WithSubscription(configuration.SubscriptionId) - .GenericResources.Define(workspaceName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(configuration.ResourceGroupName) - .WithResourceType("workspaces") - .WithProviderNamespace("Microsoft.OperationalInsights") - .WithoutPlan() - .WithApiVersion("2020-08-01") - .WithParentResource(string.Empty) - .CreateAsync(cts.Token)); - - private Task CreateAppInsightsResourceAsync(string logAnalyticsArmId) + async () => + { + OperationalInsightsWorkspaceData data = new(new(configuration.RegionName)); + return (await resourceGroup.GetOperationalInsightsWorkspaces() + .CreateOrUpdateAsync(WaitUntil.Completed, workspaceName, data, cts.Token)).Value; + }); + + private Task CreateAppInsightsResourceAsync(ResourceIdentifier logAnalyticsArmId) => Execute( $"Creating Application Insights: {configuration.ApplicationInsightsAccountName}...", - () => ResourceManager - .Configure() - .Authenticate(azureCredentials) - .WithSubscription(configuration.SubscriptionId) - .GenericResources.Define(configuration.ApplicationInsightsAccountName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(configuration.ResourceGroupName) - .WithResourceType("components") - .WithProviderNamespace("microsoft.insights") - .WithoutPlan() - .WithApiVersion("2020-02-02") - .WithParentResource(string.Empty) - .WithProperties(new Dictionary() { - { "Application_Type", "other" } , - { "WorkspaceResourceId", logAnalyticsArmId } - }) - .CreateAsync(cts.Token)); - - private Task CreateBatchAccountAsync(string storageAccountId) + async () => + { + ApplicationInsightsComponentData data = new(new(configuration.RegionName), "other") + { + FlowType = ComponentFlowType.Bluefield, + RequestSource = ComponentRequestSource.Rest, + ApplicationType = ApplicationInsightsApplicationType.Other, + WorkspaceResourceId = logAnalyticsArmId, + }; + return (await resourceGroup.GetApplicationInsightsComponents() + .CreateOrUpdateAsync(WaitUntil.Completed, configuration.ApplicationInsightsAccountName, data, cts.Token)).Value; + }); + + private Task CreateBatchAccountAsync(ResourceIdentifier storageAccountId) => Execute( $"Creating Batch Account: {configuration.BatchAccountName}...", - () => new BatchManagementClient(tokenCredentials) { SubscriptionId = configuration.SubscriptionId, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl) } - .BatchAccount - .CreateAsync( - configuration.ResourceGroupName, - configuration.BatchAccountName, - new( - configuration.RegionName, - autoStorage: configuration.PrivateNetworking.GetValueOrDefault() ? new() { StorageAccountId = storageAccountId } : null), - cts.Token)); - - private Task CreateResourceGroupAsync() + async () => + { + Batch.BatchAccountCreateOrUpdateContent data = new(new(configuration.RegionName)) + { + AutoStorage = configuration.PrivateNetworking.GetValueOrDefault() ? new(storageAccountId) : null, + }; + return (await resourceGroup.GetBatchAccounts() + .CreateOrUpdateAsync(WaitUntil.Completed, configuration.BatchAccountName, data, cts.Token)).Value; + }); + + private Task CreateResourceGroupAsync() { var tags = !string.IsNullOrWhiteSpace(configuration.Tags) ? Utility.DelimitedTextToDictionary(configuration.Tags, "=", ",") : null; - var resourceGroupDefinition = azureSubscriptionClient - .ResourceGroups - .Define(configuration.ResourceGroupName) - .WithRegion(configuration.RegionName); - - resourceGroupDefinition = tags is not null ? resourceGroupDefinition.WithTags(tags) : resourceGroupDefinition; + ResourceGroupData data = new(new(configuration.RegionName)); + (tags ?? []).ForEach(data.Tags.Add); return Execute( $"Creating Resource Group: {configuration.ResourceGroupName}...", - () => resourceGroupDefinition.CreateAsync(cts.Token)); + async () => (await armSubscription.GetResourceGroups().CreateOrUpdateAsync(WaitUntil.Completed, configuration.ResourceGroupName, data, cts.Token)).Value); } - private Task CreateUserManagedIdentityAsync(IResourceGroup resourceGroup) + private Task CreateUserManagedIdentityAsync() { // Resource group name supports periods and parenthesis but identity doesn't. Replacing them with hyphens. - var managedIdentityName = $"{resourceGroup.Name.Replace(".", "-").Replace("(", "-").Replace(")", "-")}-identity"; + var managedIdentityName = $"{resourceGroup.Id.Name.Replace(".", "-").Replace("(", "-").Replace(")", "-")}-identity"; return Execute( $"Obtaining user-managed identity: {managedIdentityName}...", - async () => await azureSubscriptionClient.Identities.GetByResourceGroupAsync(configuration.ResourceGroupName, managedIdentityName) - ?? await azureSubscriptionClient.Identities.Define(managedIdentityName) - .WithRegion(configuration.RegionName) - .WithExistingResourceGroup(resourceGroup) - .CreateAsync(cts.Token)); + async () => + { + try + { + return (await resourceGroup.GetUserAssignedIdentityAsync(managedIdentityName, cts.Token)).Value; + } + catch (RequestFailedException ex) when (ex.Status == (int)HttpStatusCode.NotFound) + { + return (await resourceGroup.GetUserAssignedIdentities().CreateOrUpdateAsync( + WaitUntil.Completed, + managedIdentityName, + new(new(configuration.RegionName)), + cts.Token)) + .Value; + } + }); } - private async Task DeleteResourceGroupAsync() + private async Task DeleteResourceGroupAsync(CancellationToken cancellationToken) { var startTime = DateTime.UtcNow; var line = ConsoleEx.WriteLine("Deleting resource group..."); - await azureSubscriptionClient.ResourceGroups.DeleteByNameAsync(configuration.ResourceGroupName, CancellationToken.None); + await resourceGroup.DeleteAsync(WaitUntil.Completed, cancellationToken: cancellationToken); WriteExecutionTime(line, startTime); } @@ -1929,9 +1940,12 @@ private static void ValidateMainIdentifierPrefix(string prefix) } } - private void ValidateRegionName(string regionName) + private async Task ValidateRegionNameAsync(string regionName) { - var validRegionNames = azureSubscriptionClient.GetCurrentSubscription().ListLocations().Select(loc => loc.Region.Name).Distinct(); + // GetAvailableLocations*() does not work https://github.com/Azure/azure-sdk-for-net/issues/28914 + var validRegionNames = await armSubscription.GetLocationsAsync(cancellationToken: cts.Token) + .Where(x => x.Metadata.RegionType == RegionType.Physical) + .Select(loc => loc.Name).Distinct().ToListAsync(cts.Token); if (!validRegionNames.Contains(regionName, StringComparer.OrdinalIgnoreCase)) { @@ -1944,33 +1958,27 @@ private async Task ValidateSubscriptionAndResourceGroupAsync(Configuration confi const string ownerRoleId = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635"; const string contributorRoleId = "b24988ac-6180-42a0-ab88-20f7382dd24c"; - var azure = Microsoft.Azure.Management.Fluent.Azure - .Configure() - .WithLogLevel(HttpLoggingDelegatingHandler.Level.Basic) - .Authenticate(azureCredentials); + bool rgExists; - var subscriptionExists = await (await azure.Subscriptions.ListAsync(cancellationToken: cts.Token)).ToAsyncEnumerable() - .AnyAsync(sub => sub.SubscriptionId.Equals(configuration.SubscriptionId, StringComparison.OrdinalIgnoreCase), cts.Token); - - if (!subscriptionExists) + try { - throw new ValidationException($"Invalid or inaccessible subcription id '{configuration.SubscriptionId}'. Make sure that subscription exists and that you are either an Owner or have Contributor and User Access Administrator roles on the subscription.", displayExample: false); + rgExists = !string.IsNullOrEmpty(configuration.ResourceGroupName) && (await armSubscription.GetResourceGroups().ExistsAsync(configuration.ResourceGroupName, cts.Token)).Value; + } + catch (Exception) + { + throw new ValidationException($"Invalid or inaccessible subscription id '{configuration.SubscriptionId}'. Make sure that subscription exists and that you are either an Owner or have Contributor and User Access Administrator roles on the subscription.", displayExample: false); } - - var rgExists = !string.IsNullOrEmpty(configuration.ResourceGroupName) && await azureSubscriptionClient.ResourceGroups.ContainAsync(configuration.ResourceGroupName, cts.Token); if (!string.IsNullOrEmpty(configuration.ResourceGroupName) && !rgExists) { throw new ValidationException($"If ResourceGroupName is provided, the resource group must already exist.", displayExample: false); } - var token = (await tokenProvider.GetAuthenticationHeaderAsync(cts.Token)).Parameter; - var currentPrincipalObjectId = new JwtSecurityTokenHandler().ReadJwtToken(token).Claims.FirstOrDefault(c => c.Type == "oid").Value; + var token = (await tokenCredential.GetTokenAsync(new([cloudEnvironment.ArmEnvironment.DefaultScope]), cts.Token)); + var currentPrincipalObjectId = new JwtSecurityTokenHandler().ReadJwtToken(token.Token).Claims.FirstOrDefault(c => c.Type == "oid").Value; - var currentPrincipalSubscriptionRoleIds = (await azureSubscriptionClient.AccessManagement.RoleAssignments.Inner.ListForScopeWithHttpMessagesAsync( - $"/subscriptions/{configuration.SubscriptionId}", new($"atScope() and assignedTo('{currentPrincipalObjectId}')"), cancellationToken: cts.Token)).Body - .ToAsyncEnumerable(async (link, ct) => (await azureSubscriptionClient.AccessManagement.RoleAssignments.Inner.ListForScopeNextWithHttpMessagesAsync(link, cancellationToken: ct)).Body) - .Select(b => b.RoleDefinitionId.Split(['/']).Last()); + var currentPrincipalSubscriptionRoleIds = armSubscription.GetRoleAssignments().GetAllAsync($"atScope() and assignedTo('{currentPrincipalObjectId}')", cancellationToken: cts.Token) + .SelectAwaitWithCancellation(async (b, c) => await FetchResourceDataAsync(t => b.GetAsync(cancellationToken: t), c)).Select(b => b.Data.RoleDefinitionId.Name); if (!await currentPrincipalSubscriptionRoleIds.AnyAsync(role => ownerRoleId.Equals(role, StringComparison.OrdinalIgnoreCase) || contributorRoleId.Equals(role, StringComparison.OrdinalIgnoreCase), cts.Token)) { @@ -1979,10 +1987,8 @@ private async Task ValidateSubscriptionAndResourceGroupAsync(Configuration confi throw new ValidationException($"Insufficient access to deploy. You must be: 1) Owner of the subscription, or 2) Contributor and User Access Administrator of the subscription, or 3) Owner of the resource group", displayExample: false); } - var currentPrincipalRgRoleIds = (await azureSubscriptionClient.AccessManagement.RoleAssignments.Inner.ListForScopeWithHttpMessagesAsync( - $"/subscriptions/{configuration.SubscriptionId}/resourceGroups/{configuration.ResourceGroupName}", new($"atScope() and assignedTo('{currentPrincipalObjectId}')"), cancellationToken: cts.Token)).Body - .ToAsyncEnumerable(async (link, ct) => (await azureSubscriptionClient.AccessManagement.RoleAssignments.Inner.ListForScopeNextWithHttpMessagesAsync(link, cancellationToken: ct)).Body) - .Select(b => b.RoleDefinitionId.Split(['/']).Last()); + var currentPrincipalRgRoleIds = resourceGroup.GetRoleAssignments().GetAllAsync($"atScope() and assignedTo('{currentPrincipalObjectId}')", cancellationToken: cts.Token) + .SelectAwaitWithCancellation(async (b, c) => await FetchResourceDataAsync(t => b.GetAsync(cancellationToken: t), c)).Select(b => b.Data.RoleDefinitionId.Name); if (!await currentPrincipalRgRoleIds.AnyAsync(role => ownerRoleId.Equals(role, StringComparison.OrdinalIgnoreCase), cts.Token)) { @@ -1991,7 +1997,7 @@ private async Task ValidateSubscriptionAndResourceGroupAsync(Configuration confi } } - private async Task ValidateAndGetExistingStorageAccountAsync() + private async Task ValidateAndGetExistingStorageAccountAsync() { if (configuration.StorageAccountName is null) { @@ -2002,7 +2008,7 @@ private async Task ValidateAndGetExistingStorageAccountAsync() ?? throw new ValidationException($"If StorageAccountName is provided, the storage account must already exist in region {configuration.RegionName}, and be accessible to the current user.", displayExample: false); } - private async Task ValidateAndGetExistingBatchAccountAsync() + private async Task ValidateAndGetExistingBatchAccountAsync() { if (configuration.BatchAccountName is null) { @@ -2013,7 +2019,7 @@ private async Task ValidateAndGetExistingBatchAccountAsync() ?? throw new ValidationException($"If BatchAccountName is provided, the batch account must already exist in region {configuration.RegionName}, and be accessible to the current user.", displayExample: false); } - private async Task<(INetwork virtualNetwork, ISubnet vmSubnet, ISubnet postgreSqlSubnet, ISubnet batchSubnet)?> ValidateAndGetExistingVirtualNetworkAsync() + private async Task<(VirtualNetworkResource virtualNetwork, SubnetResource vmSubnet, SubnetResource postgreSqlSubnet, SubnetResource batchSubnet)?> ValidateAndGetExistingVirtualNetworkAsync() { static bool AllOrNoneSet(params string[] values) => values.All(v => !string.IsNullOrEmpty(v)) || values.All(v => string.IsNullOrEmpty(v)); static bool NoneSet(params string[] values) => values.All(v => string.IsNullOrEmpty(v)); @@ -2038,58 +2044,56 @@ private async Task ValidateAndGetExistingBatchAccountAsync() throw new ValidationException($"{nameof(configuration.VnetResourceGroupName)}, {nameof(configuration.VnetName)} and {nameof(configuration.VmSubnetName)} are required when using an existing virtual network."); } - if (!await (await azureSubscriptionClient.ResourceGroups.ListAsync(true, cts.Token)).ToAsyncEnumerable().AnyAsync(rg => rg.Name.Equals(configuration.VnetResourceGroupName, StringComparison.OrdinalIgnoreCase), cts.Token)) + if (!await armSubscription.GetResourceGroups().GetAllAsync(cancellationToken: cts.Token).AnyAsync(rg => rg.Id.Name.Equals(configuration.VnetResourceGroupName, StringComparison.OrdinalIgnoreCase), cts.Token)) { throw new ValidationException($"Resource group '{configuration.VnetResourceGroupName}' does not exist."); } - var vnet = await azureSubscriptionClient.Networks.GetByResourceGroupAsync(configuration.VnetResourceGroupName, configuration.VnetName, cts.Token) ?? + var vnet = (await (await armSubscription.GetResourceGroupAsync(configuration.VnetResourceGroupName, cts.Token)).Value.GetVirtualNetworks().GetIfExistsAsync(configuration.VnetName, cancellationToken: cts.Token)).Value ?? throw new ValidationException($"Virtual network '{configuration.VnetName}' does not exist in resource group '{configuration.VnetResourceGroupName}'."); - if (!vnet.RegionName.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase)) + if (!(await FetchResourceDataAsync(ct => vnet.GetAsync(cancellationToken: ct), cts.Token, net => vnet = net)).Data.Location.Value.Name.Equals(configuration.RegionName, StringComparison.OrdinalIgnoreCase)) { throw new ValidationException($"Virtual network '{configuration.VnetName}' must be in the same region that you are deploying to ({configuration.RegionName})."); } - var vmSubnet = vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.VmSubnetName, StringComparison.OrdinalIgnoreCase)).Value ?? + var vmSubnet = await vnet.GetSubnets().GetAllAsync(cts.Token).FirstOrDefaultAsync(s => s.Id.Name.Equals(configuration.VmSubnetName, StringComparison.OrdinalIgnoreCase), cts.Token) ?? throw new ValidationException($"Virtual network '{configuration.VnetName}' does not contain subnet '{configuration.VmSubnetName}'"); - var resourceGraphClient = new ResourceGraphClient(tokenCredentials); - var postgreSqlSubnet = vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.PostgreSqlSubnetName, StringComparison.OrdinalIgnoreCase)).Value; - - if (postgreSqlSubnet is null) - { + var postgreSqlSubnet = await vnet.GetSubnets().GetAllAsync(cts.Token).FirstOrDefaultAsync(s => s.Id.Name.Equals(configuration.PostgreSqlSubnetName, StringComparison.OrdinalIgnoreCase), cts.Token) ?? throw new ValidationException($"Virtual network '{configuration.VnetName}' does not contain subnet '{configuration.PostgreSqlSubnetName}'"); - } - var delegatedServices = postgreSqlSubnet.Inner.Delegations.Select(d => d.ServiceName); + postgreSqlSubnet = await FetchResourceDataAsync(ct => postgreSqlSubnet.GetAsync(cancellationToken: ct), cts.Token); + var delegatedServices = postgreSqlSubnet.Data.Delegations.Select(d => d.ServiceName).ToList(); var hasOtherDelegations = delegatedServices.Any(s => s != "Microsoft.DBforPostgreSQL/flexibleServers"); - var hasNoDelegations = !delegatedServices.Any(); + var hasNoDelegations = 0 == delegatedServices.Count; if (hasOtherDelegations) { throw new ValidationException($"Subnet '{configuration.PostgreSqlSubnetName}' can have 'Microsoft.DBforPostgreSQL/flexibleServers' delegation only."); } - var resourcesInPostgreSqlSubnetQuery = $"where type =~ 'Microsoft.Network/networkInterfaces' | where properties.ipConfigurations[0].properties.subnet.id == '{postgreSqlSubnet.Inner.Id}'"; - var resourcesExist = (await resourceGraphClient.ResourcesAsync(new([configuration.SubscriptionId], resourcesInPostgreSqlSubnetQuery), cts.Token)).TotalRecords > 0; + Azure.ResourceManager.ResourceGraph.Models.ResourceQueryContent resourcesInPostgreSqlSubnetQuery = new($"where type =~ 'Microsoft.Network/networkInterfaces' | where properties.ipConfigurations[0].properties.subnet.id == '{postgreSqlSubnet.Id}'"); + resourcesInPostgreSqlSubnetQuery.Subscriptions.Add(configuration.SubscriptionId); + var resourcesExist = (await (await armClient.GetTenants().GetAllAsync(cts.Token).FirstAsync(cts.Token)).GetResourcesAsync(resourcesInPostgreSqlSubnetQuery, cts.Token)).Value.TotalRecords > 0; if (hasNoDelegations && resourcesExist) { throw new ValidationException($"Subnet '{configuration.PostgreSqlSubnetName}' must be either empty or have 'Microsoft.DBforPostgreSQL/flexibleServers' delegation."); } - var batchSubnet = vnet.Subnets.FirstOrDefault(s => s.Key.Equals(configuration.BatchSubnetName, StringComparison.OrdinalIgnoreCase)).Value; + var batchSubnet = await vnet.GetSubnets().GetAllAsync(cts.Token).FirstOrDefaultAsync(s => s.Id.Name.Equals(configuration.BatchSubnetName, StringComparison.OrdinalIgnoreCase), cts.Token) ?? + throw new ValidationException($"Virtual network '{configuration.VnetName}' does not contain subnet '{configuration.BatchSubnetName}'"); return (vnet, vmSubnet, postgreSqlSubnet, batchSubnet); } private async Task ValidateBatchAccountQuotaAsync() { - var batchManagementClient = new BatchManagementClient(tokenCredentials) { SubscriptionId = configuration.SubscriptionId, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl) }; - var accountQuota = (await batchManagementClient.Location.GetQuotasAsync(configuration.RegionName, cts.Token)).AccountQuota; - var existingBatchAccountCount = await (await batchManagementClient.BatchAccount.ListAsync(cts.Token)).ToAsyncEnumerable(batchManagementClient.BatchAccount.ListNextAsync) - .CountAsync(b => b.Location.Equals(configuration.RegionName), cts.Token); + var accountQuota = (await armSubscription.GetBatchQuotasAsync(new(configuration.RegionName), cts.Token)).Value.AccountQuota; + var existingBatchAccountCount = await armSubscription.GetBatchAccountsAsync(cts.Token) + .SelectAwaitWithCancellation(async (a, t) => await FetchResourceDataAsync(a.GetAsync, cts.Token)) + .CountAsync(b => b.Data.Location.Value.Name.Equals(configuration.RegionName), cts.Token); if (existingBatchAccountCount >= accountQuota) { @@ -2097,14 +2101,12 @@ private async Task ValidateBatchAccountQuotaAsync() } } - private Task UpdateVnetWithBatchSubnet(string resourceGroupId) + private Task UpdateVnetWithBatchSubnet() => Execute( $"Creating batch subnet...", async () => { - var coaRg = armClient.GetResourceGroupResource(new(resourceGroupId)); - - var vnetCollection = coaRg.GetVirtualNetworks(); + var vnetCollection = resourceGroup.GetVirtualNetworks(); var vnet = vnetCollection.FirstOrDefault(); if (vnetCollection.Count() != 1) @@ -2146,22 +2148,22 @@ private Task UpdateVnetWithBatchSubnet(string resourceGroupId) private static void AddServiceEndpointsToSubnet(SubnetData subnet) { - subnet.ServiceEndpoints.Add(new ServiceEndpointProperties() + subnet.ServiceEndpoints.Add(new() { Service = "Microsoft.Storage.Global", }); - subnet.ServiceEndpoints.Add(new ServiceEndpointProperties() + subnet.ServiceEndpoints.Add(new() { Service = "Microsoft.Sql", }); - subnet.ServiceEndpoints.Add(new ServiceEndpointProperties() + subnet.ServiceEndpoints.Add(new() { Service = "Microsoft.ContainerRegistry", }); - subnet.ServiceEndpoints.Add(new ServiceEndpointProperties() + subnet.ServiceEndpoints.Add(new() { Service = "Microsoft.KeyVault", }); @@ -2170,17 +2172,16 @@ private static void AddServiceEndpointsToSubnet(SubnetData subnet) private async Task ValidateVmAsync() { var computeSkus = await generalRetryPolicy.ExecuteAsync(async ct => - await (await azureSubscriptionClient.ComputeSkus.ListbyRegionAndResourceTypeAsync( - Region.Create(configuration.RegionName), - ComputeResourceType.VirtualMachines, - ct)) - .ToAsyncEnumerable() + await armSubscription.GetComputeResourceSkusAsync( + filter: $"location eq '{configuration.RegionName}'", + cancellationToken: ct) + .Where(s => "virtualMachines".Equals(s.ResourceType, StringComparison.OrdinalIgnoreCase)) .Where(s => !s.Restrictions.Any()) - .Select(s => s.Name.Value) + .Select(s => s.Name) .ToListAsync(ct), cts.Token); - if (!computeSkus.Any()) + if (0 == computeSkus.Count) { throw new ValidationException($"Your subscription doesn't support virtual machine creation in {configuration.RegionName}. Please create an Azure Support case: https://docs.microsoft.com/en-us/azure/azure-portal/supportability/how-to-create-azure-support-request", displayExample: false); } @@ -2190,18 +2191,15 @@ private async Task ValidateVmAsync() } } - private static async Task GetBlobClientAsync(IStorageAccount storageAccount, CancellationToken cancellationToken) - => new( - new($"https://{storageAccount.Name}.blob.{azureCloudConfig.Suffixes.StorageSuffix}"), - new StorageSharedKeyCredential( - storageAccount.Name, - (await storageAccount.GetKeysAsync(cancellationToken))[0].Value)); - private async Task ValidateTokenProviderAsync() { try { - _ = await Execute("Retrieving Azure management token...", () => new AzureServiceTokenProvider("RunAs=Developer; DeveloperTool=AzureCli").GetAccessTokenAsync(azureCloudConfig.ResourceManagerUrl, cancellationToken: cts.Token)); + _ = await Execute("Retrieving Azure management token...", + async () => await new AzureCliCredential(new() + { + AuthorityHost = cloudEnvironment.AzureAuthorityHost + }).GetTokenAsync(new([cloudEnvironment.ArmEnvironment.DefaultScope]), cancellationToken: cts.Token)); } catch (AuthenticationFailedException ex) { @@ -2386,7 +2384,9 @@ private async Task DeleteResourceGroupIfUserConsentsAsync() if (userResponse.Equals("yes", StringComparison.OrdinalIgnoreCase) || (configuration.Silent && configuration.DeleteResourceGroupOnFailure)) { - await DeleteResourceGroupAsync(); + using var token = new CancellationTokenSource(); + Console.CancelKeyPress += (o, a) => token.Cancel(true); + await DeleteResourceGroupAsync(token.Token); } } @@ -2441,21 +2441,15 @@ private async Task Execute(string message, Func> func, bool cancel private static void WriteExecutionTime(ConsoleEx.Line line, DateTime startTime) => line.Write($" Completed in {DateTime.UtcNow.Subtract(startTime).TotalSeconds:n0}s", ConsoleColor.Green); - public static async Task DownloadTextFromStorageAccountAsync(IStorageAccount storageAccount, string containerName, string blobName, CancellationToken cancellationToken) + public static async Task DownloadTextFromStorageAccountAsync(BlobClient blobClient, CancellationToken cancellationToken) { - var blobClient = await GetBlobClientAsync(storageAccount, cancellationToken); - var container = blobClient.GetBlobContainerClient(containerName); - - return (await container.GetBlobClient(blobName).DownloadContentAsync(cancellationToken)).Value.Content.ToString(); + return (await blobClient.DownloadContentAsync(cancellationToken)).Value.Content.ToString(); } - public static async Task UploadTextToStorageAccountAsync(IStorageAccount storageAccount, string containerName, string blobName, string content, CancellationToken token) + public static async Task UploadTextToStorageAccountAsync(BlobClient blobClient, string content, CancellationToken cancellationToken) { - var blobClient = await GetBlobClientAsync(storageAccount, token); - var container = blobClient.GetBlobContainerClient(containerName); - - await container.CreateIfNotExistsAsync(cancellationToken: token); - await container.GetBlobClient(blobName).UploadAsync(BinaryData.FromString(content), true, token); + await blobClient.GetParentBlobContainerClient().CreateIfNotExistsAsync(cancellationToken: cancellationToken); + await blobClient.UploadAsync(BinaryData.FromString(content), true, cancellationToken); } private class ValidationException(string reason, bool displayExample = true) : Exception diff --git a/src/deploy-tes-on-azure/KubernetesManager.cs b/src/deploy-tes-on-azure/KubernetesManager.cs index fe538b9af..75581f59b 100644 --- a/src/deploy-tes-on-azure/KubernetesManager.cs +++ b/src/deploy-tes-on-azure/KubernetesManager.cs @@ -11,15 +11,13 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Azure.ResourceManager; +using Azure.ResourceManager.ContainerService; +using Azure.ResourceManager.ManagedServiceIdentities; +using Azure.Storage.Blobs; using CommonUtilities.AzureCloud; using k8s; using k8s.Models; -using Microsoft.Azure.Management.ContainerService; -using Microsoft.Azure.Management.Msi.Fluent; -using Microsoft.Azure.Management.ResourceManager.Fluent; -using Microsoft.Azure.Management.ResourceManager.Fluent.Authentication; -using Microsoft.Azure.Management.ResourceManager.Fluent.Core; -using Microsoft.Azure.Management.Storage.Fluent; using Polly; using Polly.Retry; @@ -43,25 +41,27 @@ public class KubernetesManager private const string CertManagerRepo = "https://charts.jetstack.io"; private const string CertManagerVersion = "v1.12.3"; - private Configuration configuration { get; set; } - private AzureCredentials azureCredentials { get; set; } - private AzureCloudConfig azureCloudConfig { get; set; } - private CancellationToken cancellationToken { get; set; } + private Configuration configuration { get; } + private AzureCloudConfig azureEndpoints { get; } + private CancellationToken cancellationToken { get; } private string workingDirectoryTemp { get; set; } private string kubeConfigPath { get; set; } private string valuesTemplatePath { get; set; } - public string helmScriptsRootDirectory { get; set; } - public string TempHelmValuesYamlPath { get; set; } - public string TesCname { get; set; } + public string helmScriptsRootDirectory { get; private set; } + public string TempHelmValuesYamlPath { get; private set; } + public string TesCname { get; private set; } public string TesHostname { get; set; } - public string AzureDnsLabelName { get; set; } + public string AzureDnsLabelName { get; private set; } - public KubernetesManager(Configuration config, AzureCredentials credentials, AzureCloudConfig azureCloudConfig, CancellationToken cancellationToken) + public delegate BlobClient GetBlobClient(Azure.ResourceManager.Storage.StorageAccountData storageAccount, string containerName, string blobName); + private readonly GetBlobClient getBlobClient; + + public KubernetesManager(Configuration config, AzureCloudConfig azureCloudConfig, GetBlobClient getBlobClient, CancellationToken cancellationToken) { - this.azureCloudConfig = azureCloudConfig; + this.azureEndpoints = azureCloudConfig; this.cancellationToken = cancellationToken; configuration = config; - azureCredentials = credentials; + this.getBlobClient = getBlobClient; CreateAndInitializeWorkingDirectoriesAsync().Wait(cancellationToken); } @@ -69,22 +69,20 @@ public KubernetesManager(Configuration config, AzureCredentials credentials, Azu public void SetTesIngressNetworkingConfiguration(string prefix) { const int maxCnLength = 64; - var suffix = $".{configuration.RegionName}.cloudapp.{azureCloudConfig.Domain}"; + var suffix = $".{configuration.RegionName}.cloudapp.{azureEndpoints.Domain}"; var prefixMaxLength = maxCnLength - suffix.Length; TesCname = GetTesCname(prefix, prefixMaxLength); TesHostname = $"{TesCname}{suffix}"; AzureDnsLabelName = TesCname; } - public async Task GetKubernetesClientAsync(IResource resourceGroupObject) + public async Task GetKubernetesClientAsync(ContainerServiceManagedClusterResource managedCluster) { - var resourceGroup = resourceGroupObject.Name; - var containerServiceClient = new ContainerServiceClient(azureCredentials) { SubscriptionId = configuration.SubscriptionId, BaseUri = new Uri(azureCloudConfig.ResourceManagerUrl) }; + using MemoryStream kubeconfig = new((await managedCluster.GetClusterAdminCredentialsAsync(cancellationToken: cancellationToken)).Value.Kubeconfigs[0].Value, writable: false); - // Write kubeconfig in the working directory, because KubernetesClientConfiguration needs to read from a file, TODO figure out how to pass this directly. - var creds = await containerServiceClient.ManagedClusters.ListClusterAdminCredentialsAsync(resourceGroup, configuration.AksClusterName, cancellationToken: cancellationToken); + // Write kubeconfig in the working directory as helm & kubctl need it. var kubeConfigFile = new FileInfo(kubeConfigPath); - await File.WriteAllTextAsync(kubeConfigFile.FullName, Encoding.Default.GetString(creds.Kubeconfigs.First().Value), cancellationToken); + await File.WriteAllTextAsync(kubeConfigFile.FullName, Encoding.Default.GetString(kubeconfig.ToArray()), cancellationToken); kubeConfigFile.Refresh(); if (!OperatingSystem.IsWindows()) @@ -92,8 +90,7 @@ public async Task GetKubernetesClientAsync(IResource resourceGroupO kubeConfigFile.UnixFileMode = UnixFileMode.UserRead | UnixFileMode.UserWrite; } - var k8sConfiguration = KubernetesClientConfiguration.LoadKubeConfig(kubeConfigFile, false); - var k8sClientConfiguration = KubernetesClientConfiguration.BuildConfigFromConfigObject(k8sConfiguration); + var k8sClientConfiguration = KubernetesClientConfiguration.BuildConfigFromConfigFile(kubeconfig); return new Kubernetes(k8sClientConfiguration); } @@ -244,18 +241,18 @@ public async Task GetHelmValuesAsync(string valuesTemplatePath) return values; } - public async Task UpdateHelmValuesAsync(IStorageAccount storageAccount, string keyVaultUrl, string resourceGroupName, Dictionary settings, IIdentity managedId) + public async Task UpdateHelmValuesAsync(Azure.ResourceManager.Storage.StorageAccountData storageAccount, Uri keyVaultUrl, string resourceGroupName, Dictionary settings, UserAssignedIdentityData managedId) { var values = await GetHelmValuesAsync(valuesTemplatePath); UpdateValuesFromSettings(values, settings); values.Config["resourceGroup"] = resourceGroupName; values.Identity["name"] = managedId.Name; - values.Identity["resourceId"] = managedId.Id; - values.Identity["clientId"] = managedId.ClientId; + values.Identity["resourceId"] = managedId.Id.ToString(); + values.Identity["clientId"] = managedId.ClientId?.ToString("D"); if (configuration.CrossSubscriptionAKSDeployment.GetValueOrDefault()) { - values.InternalContainersKeyVaultAuth = new List>(); + values.InternalContainersKeyVaultAuth = []; foreach (var container in values.DefaultContainers) { @@ -263,7 +260,7 @@ public async Task UpdateHelmValuesAsync(IStorageAccount storageAccount, string k { { "accountName", storageAccount.Name }, { "containerName", container }, - { "keyVaultURL", keyVaultUrl }, + { "keyVaultURL", keyVaultUrl.AbsoluteUri }, { "keyVaultSecretName", Deployer.StorageAccountKeySecretName} }; @@ -272,7 +269,7 @@ public async Task UpdateHelmValuesAsync(IStorageAccount storageAccount, string k } else { - values.InternalContainersMIAuth = new List>(); + values.InternalContainersMIAuth = []; foreach (var container in values.DefaultContainers) { @@ -289,16 +286,17 @@ public async Task UpdateHelmValuesAsync(IStorageAccount storageAccount, string k var valuesString = KubernetesYaml.Serialize(values); await File.WriteAllTextAsync(TempHelmValuesYamlPath, valuesString, cancellationToken); - await Deployer.UploadTextToStorageAccountAsync(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml", valuesString, cancellationToken); + await Deployer.UploadTextToStorageAccountAsync(getBlobClient(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml"), valuesString, cancellationToken); } - public async Task UpgradeValuesYamlAsync(IStorageAccount storageAccount, Dictionary settings) + public async Task UpgradeValuesYamlAsync(Azure.ResourceManager.Storage.StorageAccountData storageAccount, Dictionary settings) { - var values = KubernetesYaml.Deserialize(await Deployer.DownloadTextFromStorageAccountAsync(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml", cancellationToken)); + var blobClient = getBlobClient(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml"); + var values = KubernetesYaml.Deserialize(await Deployer.DownloadTextFromStorageAccountAsync(blobClient, cancellationToken)); UpdateValuesFromSettings(values, settings); var valuesString = KubernetesYaml.Serialize(values); await File.WriteAllTextAsync(TempHelmValuesYamlPath, valuesString, cancellationToken); - await Deployer.UploadTextToStorageAccountAsync(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml", valuesString, cancellationToken); + await Deployer.UploadTextToStorageAccountAsync(blobClient, valuesString, cancellationToken); } public async Task ConfigureAltLocalValuesYamlAsync(string altName, Action configure) @@ -322,9 +320,9 @@ public void RestoreLocalValuesYaml(FileInfo backup) File.Replace(backup.FullName, TempHelmValuesYamlPath, default); } - public async Task> GetAKSSettingsAsync(IStorageAccount storageAccount) + public async Task> GetAKSSettingsAsync(Azure.ResourceManager.Storage.StorageAccountData storageAccount) { - var values = KubernetesYaml.Deserialize(await Deployer.DownloadTextFromStorageAccountAsync(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml", cancellationToken)); + var values = KubernetesYaml.Deserialize(await Deployer.DownloadTextFromStorageAccountAsync(getBlobClient(storageAccount, Deployer.ConfigurationContainerName, "aksValues.yaml"), cancellationToken)); return ValuesToSettings(values); } @@ -539,7 +537,7 @@ private static Dictionary ValuesToSettings(HelmValues values) /// private static string GetTesCname(string prefix, int maxLength = 40) { - var tempCname = SdkContext.RandomResourceName($"{prefix.Replace(".", "")}-", maxLength); + var tempCname = Utility.RandomResourceName($"{prefix.Replace(".", "")}-", maxLength); if (tempCname.Length > maxLength) { diff --git a/src/deploy-tes-on-azure/Program.cs b/src/deploy-tes-on-azure/Program.cs index 5d9537b35..a4d957e60 100644 --- a/src/deploy-tes-on-azure/Program.cs +++ b/src/deploy-tes-on-azure/Program.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using TesDeployer; +System.Threading.Thread.CurrentThread.CurrentCulture = System.Globalization.CultureInfo.InvariantCulture; await InitializeAndDeployAsync(args); static async Task InitializeAndDeployAsync(string[] args) diff --git a/src/deploy-tes-on-azure/Utility.cs b/src/deploy-tes-on-azure/Utility.cs index 6335ec499..07db2ed1b 100644 --- a/src/deploy-tes-on-azure/Utility.cs +++ b/src/deploy-tes-on-azure/Utility.cs @@ -13,6 +13,16 @@ namespace TesDeployer { public static class Utility { + /// + /// Generates a random resource names with the prefix. + /// + /// the prefix to be used if possible + /// the maximum length for the random generated name + /// random name + /// Implementation of Microsoft.Azure.Management.ResourceManager.Fluent.SdkContext.RandomResourceName + public static string RandomResourceName(string prefix, int maxLength) + => new ResourceNamer(string.Empty).RandomName(prefix, maxLength); + public static string DictionaryToDelimitedText(Dictionary dictionary, string fieldDelimiter = "=", string rowDelimiter = "\n") => string.Join(rowDelimiter, dictionary.Select(kv => $"{kv.Key}{fieldDelimiter}{kv.Value}")); @@ -134,5 +144,41 @@ public static string GetFileContent(params string[] pathComponentsRelativeToAppB private static Stream GetBinaryFileContent(params string[] pathComponentsRelativeToAppBase) => typeof(Deployer).Assembly.GetManifestResourceStream($"deploy-tes-on-azure.{string.Join(".", pathComponentsRelativeToAppBase)}"); + + // borrowed from https://github.com/Azure/azure-libraries-for-net/blob/7d85e294e4e7280c3f74b1c41438e2f20bce2052/src/ResourceManagement/ResourceManager/ResourceNamer.cs + private class ResourceNamer(string name) + { + private readonly string randName = name.ToLowerInvariant() + Guid.NewGuid().ToString("N")[..3].ToLowerInvariant(); + private static readonly Random random = new(); + + public string RandomName(string prefix, int maxLen) + { + lock (random) // https://learn.microsoft.com/dotnet/fundamentals/runtime-libraries/system-random#thread-safety + { + prefix = prefix.ToLowerInvariant(); + var minRandomnessLength = 5; + var minRandomString = random.Next(0, 100000).ToString("D5"); + + if (maxLen < (prefix.Length + randName.Length + minRandomnessLength)) + { + var str1 = prefix + minRandomString; + return str1 + RandomString((maxLen - str1.Length) / 2); + } + + var str = prefix + randName + minRandomString; + return str + RandomString((maxLen - str.Length) / 2); + } + } + + private static string RandomString(int length) + { + var str = ""; + while (str.Length < length) + { + str += Guid.NewGuid().ToString("N")[..Math.Min(32, length)].ToLowerInvariant(); + } + return str; + } + } } } diff --git a/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj b/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj index 46894ee87..aa565aa4e 100644 --- a/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj +++ b/src/deploy-tes-on-azure/deploy-tes-on-azure.csproj @@ -18,19 +18,24 @@ - - - + + + + + + + + + + + + + + - + - - - - - - - + @@ -40,11 +45,11 @@ + - - +