Skip to content

Commit

Permalink
Refresh cached credentials after PreAuthenticate fails (dotnet#101053)
Browse files Browse the repository at this point in the history
* Support refreshing credentials in pre-auth cache

* Fix minor bug in CredentialCache

* Add unit test

* Fix tests

* Fix tests attempt 2

* Merge two lock statements.

* Fix build
  • Loading branch information
rzikm authored and michaelgsharp committed May 8, 2024
1 parent 39f7495 commit 000c7c8
Show file tree
Hide file tree
Showing 10 changed files with 363 additions and 130 deletions.
133 changes: 133 additions & 0 deletions src/libraries/Common/src/System/Net/CredentialCacheKey.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;

namespace System.Net
{
internal sealed class CredentialCacheKey : IEquatable<CredentialCacheKey?>
{
public readonly Uri UriPrefix;
public readonly int UriPrefixLength = -1;
public readonly string AuthenticationType;

internal CredentialCacheKey(Uri uriPrefix, string authenticationType)
{
Debug.Assert(uriPrefix != null);
Debug.Assert(authenticationType != null);

UriPrefix = uriPrefix;
UriPrefixLength = UriPrefix.AbsolutePath.LastIndexOf('/');
AuthenticationType = authenticationType;
}

internal bool Match(Uri uri, string authenticationType)
{
if (uri == null || authenticationType == null)
{
return false;
}

// If the protocols don't match, this credential is not applicable for the given Uri.
if (!string.Equals(authenticationType, AuthenticationType, StringComparison.OrdinalIgnoreCase))
{
return false;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Match({UriPrefix} & {uri})");

return IsPrefix(uri, UriPrefix);
}

// IsPrefix (Uri)
//
// Determines whether <prefixUri> is a prefix of this URI. A prefix
// match is defined as:
//
// scheme match
// + host match
// + port match, if any
// + <prefix> path is a prefix of <URI> path, if any
//
// Returns:
// True if <prefixUri> is a prefix of this URI
private static bool IsPrefix(Uri uri, Uri prefixUri)
{
Debug.Assert(uri != null);
Debug.Assert(prefixUri != null);

if (prefixUri.Scheme != uri.Scheme || prefixUri.Host != uri.Host || prefixUri.Port != uri.Port)
{
return false;
}

int prefixLen = prefixUri.AbsolutePath.LastIndexOf('/');
if (prefixLen > uri.AbsolutePath.LastIndexOf('/'))
{
return false;
}

return string.Compare(uri.AbsolutePath, 0, prefixUri.AbsolutePath, 0, prefixLen, StringComparison.OrdinalIgnoreCase) == 0;
}

public override int GetHashCode() =>
StringComparer.OrdinalIgnoreCase.GetHashCode(AuthenticationType) ^
UriPrefix.GetHashCode();

public bool Equals([NotNullWhen(true)] CredentialCacheKey? other)
{
if (other == null)
{
return false;
}

bool equals =
string.Equals(AuthenticationType, other.AuthenticationType, StringComparison.OrdinalIgnoreCase) &&
UriPrefix.Equals(other.UriPrefix);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Equals({this},{other}) returns {equals}");

return equals;
}

public override bool Equals([NotNullWhen(true)] object? obj) => Equals(obj as CredentialCacheKey);

public override string ToString() =>
string.Create(CultureInfo.InvariantCulture, $"[{UriPrefixLength}]:{UriPrefix}:{AuthenticationType}");
}

internal static class CredentialCacheHelper
{
public static bool TryGetCredential(Dictionary<CredentialCacheKey, NetworkCredential> cache, Uri uriPrefix, string authType, [NotNullWhen(true)] out Uri? mostSpecificMatchUri, [NotNullWhen(true)] out NetworkCredential? mostSpecificMatch)
{
int longestMatchPrefix = -1;
mostSpecificMatch = null;
mostSpecificMatchUri = null;

// Enumerate through every credential in the cache
foreach ((CredentialCacheKey key, NetworkCredential value) in cache)
{
// Determine if this credential is applicable to the current Uri/AuthType
if (key.Match(uriPrefix, authType))
{
int prefixLen = key.UriPrefixLength;

// Check if the match is better than the current-most-specific match
if (prefixLen > longestMatchPrefix)
{
// Yes: update the information about currently preferred match
longestMatchPrefix = prefixLen;
mostSpecificMatch = value;
mostSpecificMatchUri = key.UriPrefix;
}
}
}

return mostSpecificMatch != null;
}
}
}
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Http/src/System.Net.Http.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@
Link="Common\System\Text\ValueStringBuilder.AppendSpanFormattable.cs" />
<Compile Include="$(CommonPath)System\Obsoletions.cs"
Link="Common\System\Obsoletions.cs" />
<Compile Include="$(CommonPath)System\Net\CredentialCacheKey.cs"
Link="Common\System\Net\CredentialCacheKey.cs" />
</ItemGroup>

<!-- SocketsHttpHandler implementation -->
Expand Down Expand Up @@ -216,6 +218,7 @@
<Compile Include="System\Net\Http\SocketsHttpHandler\IHttpTrace.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\IMultiWebProxy.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\MultiProxy.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\PreAuthCredentialCache.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\RawConnectionStream.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\RedirectHandler.cs" />
<Compile Include="System\Net\Http\SocketsHttpHandler\SocketsHttpConnectionContext.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,25 +215,26 @@ private static async ValueTask<HttpResponseMessage> SendWithAuthAsync(HttpReques
// If preauth is enabled and this isn't proxy auth, try to get a basic credential from the
// preauth credentials cache, and if successful, set an auth header for it onto the request.
// Currently we only support preauth for Basic.
bool performedBasicPreauth = false;
NetworkCredential? preAuthCredential = null;
Uri? preAuthCredentialUri = null;
if (preAuthenticate)
{
Debug.Assert(pool.PreAuthCredentials != null);
NetworkCredential? credential;
(Uri uriPrefix, NetworkCredential credential)? preAuthCredentialPair;
lock (pool.PreAuthCredentials)
{
// Just look for basic credentials. If in the future we support preauth
// for other schemes, this will need to search in order of precedence.
Debug.Assert(pool.PreAuthCredentials.GetCredential(authUri, NegotiateScheme) == null);
Debug.Assert(pool.PreAuthCredentials.GetCredential(authUri, NtlmScheme) == null);
Debug.Assert(pool.PreAuthCredentials.GetCredential(authUri, DigestScheme) == null);
credential = pool.PreAuthCredentials.GetCredential(authUri, BasicScheme);
preAuthCredentialPair = pool.PreAuthCredentials.GetCredential(authUri, BasicScheme);
}

if (credential != null)
if (preAuthCredentialPair != null)
{
SetBasicAuthToken(request, credential, isProxyAuth);
performedBasicPreauth = true;
(preAuthCredentialUri, preAuthCredential) = preAuthCredentialPair.Value;
SetBasicAuthToken(request, preAuthCredential, isProxyAuth);
}
}

Expand Down Expand Up @@ -265,13 +266,21 @@ await TrySetDigestAuthToken(request, challenge.Credential, digestResponse, isPro
break;

case AuthenticationType.Basic:
if (performedBasicPreauth)
if (preAuthCredential != null)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.AuthenticationError(authUri, $"Pre-authentication with {(isProxyAuth ? "proxy" : "server")} failed.");
}
break;

if (challenge.Credential == preAuthCredential)
{
// Pre auth failed, and user supplied credentials are still same, we can stop there.
break;
}

// Pre-auth credentials have changed, continue with the new ones.
// The old ones will be removed below.
}

response.Dispose();
Expand All @@ -293,6 +302,17 @@ await TrySetDigestAuthToken(request, challenge.Credential, digestResponse, isPro
default:
lock (pool.PreAuthCredentials!)
{
// remove previously cached (failing) creds
if (preAuthCredentialUri != null)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(pool.PreAuthCredentials, $"Removing Basic credential from cache, uri={preAuthCredentialUri}, username={preAuthCredential!.UserName}");
}

pool.PreAuthCredentials.Remove(preAuthCredentialUri, BasicScheme);
}

try
{
if (NetEventSource.Log.IsEnabled())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ internal sealed partial class HttpConnectionPool : IDisposable
private SslClientAuthenticationOptions? _sslOptionsHttp3;
private readonly SslClientAuthenticationOptions? _sslOptionsProxy;

private readonly CredentialCache? _preAuthCredentials;
private readonly PreAuthCredentialCache? _preAuthCredentials;

/// <summary>Whether the pool has been used since the last time a cleanup occurred.</summary>
private bool _usedSinceLastCleanup = true;
Expand Down Expand Up @@ -237,7 +237,7 @@ public HttpConnectionPool(HttpConnectionPoolManager poolManager, HttpConnectionK
// Set up for PreAuthenticate. Access to this cache is guarded by a lock on the cache itself.
if (_poolManager.Settings._preAuthenticate)
{
_preAuthCredentials = new CredentialCache();
_preAuthCredentials = new PreAuthCredentialCache();
}

_http11RequestQueue = new RequestQueue<HttpConnection>();
Expand Down Expand Up @@ -296,7 +296,7 @@ private static SslClientAuthenticationOptions ConstructSslOptions(HttpConnection
public bool IsSecure => _kind == HttpConnectionKind.Https || _kind == HttpConnectionKind.SslProxyTunnel || _kind == HttpConnectionKind.SslSocksTunnel;
public Uri? ProxyUri => _proxyUri;
public ICredentials? ProxyCredentials => _poolManager.ProxyCredentials;
public CredentialCache? PreAuthCredentials => _preAuthCredentials;
public PreAuthCredentialCache? PreAuthCredentials => _preAuthCredentials;
public bool IsDefaultPort => OriginAuthority.Port == (IsSecure ? DefaultHttpsPort : DefaultHttpPort);
private bool DoProxyAuth => (_kind == HttpConnectionKind.Proxy || _kind == HttpConnectionKind.ProxyConnect);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;

namespace System.Net.Http
{
internal sealed class PreAuthCredentialCache
{
private Dictionary<CredentialCacheKey, NetworkCredential>? _cache;

public void Add(Uri uriPrefix, string authType, NetworkCredential cred)
{
Debug.Assert(uriPrefix != null);
Debug.Assert(authType != null);

var key = new CredentialCacheKey(uriPrefix, authType);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Adding key:[{key}], cred:[{cred.Domain}],[{cred.UserName}]");

_cache ??= new Dictionary<CredentialCacheKey, NetworkCredential>();
_cache.Add(key, cred);
}

public void Remove(Uri uriPrefix, string authType)
{
Debug.Assert(uriPrefix != null);
Debug.Assert(authType != null);

if (_cache == null)
{
return;
}

var key = new CredentialCacheKey(uriPrefix, authType);
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Removing key:[{key}]");
_cache.Remove(key);
}

public (Uri uriPrefix, NetworkCredential credential)? GetCredential(Uri uriPrefix, string authType)
{
Debug.Assert(uriPrefix != null);
Debug.Assert(authType != null);

if (_cache == null)
{
return null;
}

CredentialCacheHelper.TryGetCredential(_cache, uriPrefix, authType, out Uri? mostSpecificMatchUri, out NetworkCredential? mostSpecificMatch);

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"Returning {(mostSpecificMatch == null ? "null" : "(" + mostSpecificMatch.UserName + ":" + mostSpecificMatch.Domain + ")")}");

return mostSpecificMatch == null ? null : (mostSpecificMatchUri!, mostSpecificMatch!);
}
}
}
Loading

0 comments on commit 000c7c8

Please sign in to comment.