diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/CloudConnection.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/CloudConnection.cs index 53e3e3ce911..595cad4bf10 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/CloudConnection.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/CloudConnection.cs @@ -26,6 +26,7 @@ class CloudConnection : ICloudConnection const int TokenTimeToLiveSeconds = 3600; // Unused - Token is generated by downstream clients const int TokenExpiryBufferPercentage = 8; // Assuming a standard token for 1 hr, we set expiry time to around 5 mins. const uint OperationTimeoutMilliseconds = 1 * 60 * 1000; // 1 min + static readonly TimeSpan TokenRetryWaitTime = TimeSpan.FromSeconds(20); readonly Action connectionStatusChangedHandler; readonly ITransportSettings[] transportSettingsList; @@ -103,15 +104,15 @@ public async Task CreateOrUpdateAsync(IClientCredentials newCredent if (newCredentials is ITokenCredentials tokenAuth && this.tokenGetter.HasValue) { if (IsTokenExpired(tokenAuth.Identity.IotHubHostName, tokenAuth.Token)) - { + { throw new InvalidOperationException($"Token for client {tokenAuth.Identity.Id} is expired"); } this.tokenGetter.ForEach(tg => { - tg.SetResult(tokenAuth.Token); + // First reset the token getter and then set the result. this.tokenGetter = Option.None>(); - Events.NewTokenObtained(newCredentials.Identity.IotHubHostName, newCredentials.Identity.Id, tokenAuth.Token); + tg.SetResult(tokenAuth.Token); }); return (cp, false); } @@ -188,7 +189,7 @@ async Task CreateDeviceClient( { client.SetProductInfo(newCredentials.ProductInfo); } - + Events.CreateDeviceClientSuccess(transportSettings.GetTransportType(), OperationTimeoutMilliseconds, newCredentials.Identity); return client; } @@ -274,36 +275,63 @@ void InternalConnectionStatusChangesHandler(ConnectionStatus status, ConnectionS /// If the existing identity has a usable token, then use it. /// Else, generate a notification of token being near expiry and return a task that /// can be completed later. + /// Keep retrying till we get a usable token. /// Note - Don't use this.Identity in this method, as it may not have been set yet! /// async Task GetNewToken(string iotHub, string id, string currentToken, IIdentity currentIdentity) { Events.GetNewToken(id); - // We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from - // Device Client and it throws if the token is expired. - if (IsTokenUsable(iotHub, currentToken)) + bool retrying = false; + string token = currentToken; + while (true) { - Events.UsingExistingToken(id); - return currentToken; - } - else - { - Events.TokenExpired(id, currentToken); - } + // We have to catch UnauthorizedAccessException, because on IsTokenUsable, we call parse from + // Device Client and it throws if the token is expired. + if (IsTokenUsable(iotHub, token)) + { + if (retrying) + { + Events.NewTokenObtained(iotHub, id, token); + } + else + { + Events.UsingExistingToken(id); + } + return token; + } + else + { + Events.TokenNotUsable(iotHub, id, token); + } - // No need to lock here as the lock is being held by the refresher. - TaskCompletionSource tcs = this.tokenGetter - .GetOrElse( - () => + bool newTokenGetterCreated = false; + // No need to lock here as the lock is being held by the refresher. + TaskCompletionSource tcs = this.tokenGetter + .GetOrElse( + () => + { + Events.SafeCreateNewToken(id); + var taskCompletionSource = new TaskCompletionSource(); + this.tokenGetter = Option.Some(taskCompletionSource); + newTokenGetterCreated = true; + return taskCompletionSource; + }); + + // If a new tokenGetter was created, then invoke the connection status changed handler + if (newTokenGetterCreated) + { + // If retrying, wait for some time. + if (retrying) { - Events.SafeCreateNewToken(id); - var taskCompletionSource = new TaskCompletionSource(); - this.tokenGetter = Option.Some(taskCompletionSource); - this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry); - return taskCompletionSource; - }); - string newToken = await tcs.Task; - return newToken; + await Task.Delay(TokenRetryWaitTime); + } + this.connectionStatusChangedHandler(currentIdentity.Id, CloudConnectionStatus.TokenNearExpiry); + } + + retrying = true; + // this.tokenGetter will be reset when this task returns. + token = await tcs.Task; + } } internal static DateTime GetTokenExpiry(string hostName, string token) @@ -427,7 +455,6 @@ enum EventIds CreateNewToken, UpdatedCloudConnection, ObtainedNewToken, - TokenExpired, ErrorRenewingToken, ErrorCheckingTokenUsability } @@ -486,11 +513,6 @@ internal static void NewTokenObtained(string hostname, string id, string newToke Log.LogInformation((int)EventIds.ObtainedNewToken, Invariant($"Obtained new token for client {id} that expires in {timeRemaining}")); } - internal static void TokenExpired(string id, string currentToken) - { - Log.LogDebug((int)EventIds.TokenExpired, Invariant($"Token Expired. Id:{id}, CurrentToken: {currentToken}.")); - } - internal static void ErrorRenewingToken(Exception ex) { Log.LogDebug((int)EventIds.ErrorRenewingToken, ex, "Critical Error trying to renew Token."); @@ -500,6 +522,12 @@ public static void ErrorCheckingTokenUsable(Exception ex) { Log.LogDebug((int)EventIds.ErrorCheckingTokenUsability, ex, "Error checking if token is usable."); } + + public static void TokenNotUsable(string hostname, string id, string newToken) + { + TimeSpan timeRemaining = GetTokenExpiryTimeRemaining(hostname, newToken); + Log.LogDebug((int)EventIds.ObtainedNewToken, Invariant($"Token received for client {id} expires in {timeRemaining}, and so is not usable. Getting a fresh token...")); + } } } } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/CloudConnectionTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/CloudConnectionTest.cs index 90594f9a3aa..9770d585183 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/CloudConnectionTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/CloudConnectionTest.cs @@ -29,7 +29,7 @@ public void GetTokenExpiryBufferSecondsTest() string token = TokenHelper.CreateSasToken("azure.devices.net"); TimeSpan timeRemaining = CloudConnection.GetTokenExpiryTimeRemaining("foo.azuredevices.net", token); Assert.True(timeRemaining > TimeSpan.Zero); - } + } [Unit] [Fact] @@ -81,7 +81,14 @@ public async Task RefreshTokenTest() IClientCredentials GetClientCredentialsWithExpiringToken() { - string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddSeconds(10)); + string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3)); + var identity = new DeviceIdentity(iothubHostName, deviceId); + return new TokenCredentials(identity, token, string.Empty); + } + + IClientCredentials GetClientCredentialsWithNonExpiringToken() + { + string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10)); var identity = new DeviceIdentity(iothubHostName, deviceId); return new TokenCredentials(identity, token, string.Empty); } @@ -106,28 +113,105 @@ IClientCredentials GetClientCredentialsWithExpiringToken() var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh; Assert.NotNull(deviceAuthenticationWithTokenRefresh); - // Wait for the token to expire - await Task.Delay(TimeSpan.FromSeconds(10)); - Task getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName); Assert.False(getTokenTask.IsCompleted); Assert.Equal(receivedStatus, CloudConnectionStatus.TokenNearExpiry); - IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithExpiringToken(); + IClientCredentials clientCredentialsWithExpiringToken2 = GetClientCredentialsWithNonExpiringToken(); ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken2); // Wait for the task to complete await Task.Delay(TimeSpan.FromSeconds(10)); + Assert.True(getTokenTask.IsCompletedSuccessfully); Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault()); Assert.True(cloudProxy2.IsActive); Assert.True(cloudProxy1.IsActive); Assert.Equal(cloudProxy1, cloudProxy2); - Assert.True(getTokenTask.IsCompletedSuccessfully); Assert.Equal(getTokenTask.Result, (clientCredentialsWithExpiringToken2 as ITokenCredentials)?.Token); } + [Fact] + [Unit] + public async Task RefreshTokenWithRetryTest() + { + string iothubHostName = "test.azure-devices.net"; + string deviceId = "device1"; + + IClientCredentials GetClientCredentialsWithExpiringToken() + { + string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(3)); + var identity = new DeviceIdentity(iothubHostName, deviceId); + return new TokenCredentials(identity, token, string.Empty); + } + + IClientCredentials GetClientCredentialsWithNonExpiringToken() + { + string token = TokenHelper.CreateSasToken(iothubHostName, DateTime.UtcNow.AddMinutes(10)); + var identity = new DeviceIdentity(iothubHostName, deviceId); + return new TokenCredentials(identity, token, string.Empty); + } + + IAuthenticationMethod authenticationMethod = null; + IClientProvider clientProvider = GetMockDeviceClientProviderWithToken((s, a, t) => authenticationMethod = a); + + var transportSettings = new ITransportSettings[] { new AmqpTransportSettings(TransportType.Amqp_Tcp_Only) }; + + var receivedStatuses = new List(); + void ConnectionStatusHandler(string id, CloudConnectionStatus status) => receivedStatuses.Add(status); + var messageConverterProvider = new MessageConverterProvider(new Dictionary { [typeof(TwinCollection)] = Mock.Of() }); + + var cloudConnection = new CloudConnection(ConnectionStatusHandler, transportSettings, messageConverterProvider, clientProvider, Mock.Of(), TokenProvider, DeviceScopeIdentitiesCache, TimeSpan.FromMinutes(60)); + + IClientCredentials clientCredentialsWithExpiringToken1 = GetClientCredentialsWithExpiringToken(); + ICloudProxy cloudProxy1 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1); + Assert.True(cloudProxy1.IsActive); + Assert.Equal(cloudProxy1, cloudConnection.CloudProxy.OrDefault()); + + Assert.NotNull(authenticationMethod); + var deviceAuthenticationWithTokenRefresh = authenticationMethod as DeviceAuthenticationWithTokenRefresh; + Assert.NotNull(deviceAuthenticationWithTokenRefresh); + + // Try to refresh token but get an expiring token + Task getTokenTask = deviceAuthenticationWithTokenRefresh.GetTokenAsync(iothubHostName); + Assert.False(getTokenTask.IsCompleted); + + Assert.Equal(2, receivedStatuses.Count); + Assert.Equal(receivedStatuses[1], CloudConnectionStatus.TokenNearExpiry); + + ICloudProxy cloudProxy2 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithExpiringToken1); + + // Wait for the task to process + await Task.Delay(TimeSpan.FromSeconds(5)); + + Assert.False(getTokenTask.IsCompletedSuccessfully); + Assert.Equal(cloudProxy2, cloudConnection.CloudProxy.OrDefault()); + Assert.True(cloudProxy2.IsActive); + Assert.True(cloudProxy1.IsActive); + Assert.Equal(cloudProxy1, cloudProxy2); + + // Wait for 20 secs for retry to happen + await Task.Delay(TimeSpan.FromSeconds(20)); + + // Check if retry happened + Assert.Equal(3, receivedStatuses.Count); + Assert.Equal(receivedStatuses[2], CloudConnectionStatus.TokenNearExpiry); + + IClientCredentials clientCredentialsWithNonExpiringToken = GetClientCredentialsWithNonExpiringToken(); + ICloudProxy cloudProxy3 = await cloudConnection.CreateOrUpdateAsync(clientCredentialsWithNonExpiringToken); + + // Wait for the task to complete + await Task.Delay(TimeSpan.FromSeconds(5)); + + Assert.True(getTokenTask.IsCompletedSuccessfully); + Assert.Equal(cloudProxy3, cloudConnection.CloudProxy.OrDefault()); + Assert.True(cloudProxy3.IsActive); + Assert.True(cloudProxy1.IsActive); + Assert.Equal(cloudProxy1, cloudProxy3); + Assert.Equal(getTokenTask.Result, (clientCredentialsWithNonExpiringToken as ITokenCredentials)?.Token); + } + [Fact] [Unit] public async Task CloudConnectionCallbackTest() @@ -205,9 +289,9 @@ public async Task UpdateDeviceConnectionTest() string hostname = "dummy.azure-devices.net"; string deviceId = "device1"; - IClientCredentials GetClientCredentials() + IClientCredentials GetClientCredentials(TimeSpan tokenExpiryDuration) { - string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(10)); + string token = TokenHelper.CreateSasToken(hostname, DateTime.UtcNow.AddSeconds(tokenExpiryDuration.TotalSeconds)); var identity = new DeviceIdentity(hostname, deviceId); return new TokenCredentials(identity, token, string.Empty); } @@ -258,7 +342,7 @@ IClient GetMockedDeviceClient() var credentialsCache = Mock.Of(); IConnectionManager connectionManager = new ConnectionManager(cloudConnectionProvider, credentialsCache, deviceId, "$edgeHub"); - IClientCredentials clientCredentials1 = GetClientCredentials(); + IClientCredentials clientCredentials1 = GetClientCredentials(TimeSpan.FromSeconds(10)); Try cloudProxyTry1 = await connectionManager.CreateCloudConnectionAsync(clientCredentials1); Assert.True(cloudProxyTry1.Success); @@ -272,7 +356,7 @@ IClient GetMockedDeviceClient() Task tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname); Assert.False(tokenGetter.IsCompleted); - IClientCredentials clientCredentials2 = GetClientCredentials(); + IClientCredentials clientCredentials2 = GetClientCredentials(TimeSpan.FromMinutes(2)); Try cloudProxyTry2 = await connectionManager.CreateCloudConnectionAsync(clientCredentials2); Assert.True(cloudProxyTry2.Success); @@ -280,28 +364,18 @@ IClient GetMockedDeviceClient() await connectionManager.AddDeviceConnection(clientCredentials2.Identity, deviceProxy2); await Task.Delay(TimeSpan.FromSeconds(3)); - Assert.True(tokenGetter.IsCompleted); - Assert.Equal(tokenGetter.Result, (clientCredentials2 as ITokenCredentials)?.Token); - - await Task.Delay(TimeSpan.FromSeconds(10)); - Assert.NotNull(authenticationMethod); - deviceTokenRefresher = authenticationMethod as DeviceAuthenticationWithTokenRefresh; - Assert.NotNull(deviceTokenRefresher); - tokenGetter = deviceTokenRefresher.GetTokenAsync(hostname); Assert.False(tokenGetter.IsCompleted); - IClientCredentials clientCredentials3 = GetClientCredentials(); + IClientCredentials clientCredentials3 = GetClientCredentials(TimeSpan.FromMinutes(10)); Try cloudProxyTry3 = await connectionManager.CreateCloudConnectionAsync(clientCredentials3); Assert.True(cloudProxyTry3.Success); IDeviceProxy deviceProxy3 = GetMockDeviceProxy(); await connectionManager.AddDeviceConnection(clientCredentials3.Identity, deviceProxy3); - await Task.Delay(TimeSpan.FromSeconds(3)); + await Task.Delay(TimeSpan.FromSeconds(23)); Assert.True(tokenGetter.IsCompleted); Assert.Equal(tokenGetter.Result, (clientCredentials3 as ITokenCredentials)?.Token); - - Mock.VerifyAll(Mock.Get(deviceProxy1), Mock.Get(deviceProxy2)); } static async Task GetCloudConnectionTest(Func credentialsGenerator, IClientProvider clientProvider)