Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add limited support for LocalCertificateSelectionCallback for QUIC #70716

Merged
merged 4 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,36 @@ public static SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions o

if (options.ClientAuthenticationOptions != null)
{
SslClientAuthenticationOptions clientAuthenticationOptions = options.ClientAuthenticationOptions;

#pragma warning disable SYSLIB0040 // NoEncryption and AllowNoEncryption are obsolete
if (options.ClientAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption)
if (clientAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.NoEncryption)
{
throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(options.ClientAuthenticationOptions.EncryptionPolicy)));
throw new PlatformNotSupportedException(SR.Format(SR.net_quic_ssl_option, nameof(clientAuthenticationOptions.EncryptionPolicy)));
}
#pragma warning restore SYSLIB0040

if (options.ClientAuthenticationOptions.ClientCertificates != null)
if (clientAuthenticationOptions.LocalCertificateSelectionCallback != null)
{
X509Certificate? cert = clientAuthenticationOptions.LocalCertificateSelectionCallback(
options,
clientAuthenticationOptions.TargetHost ?? string.Empty,
clientAuthenticationOptions.ClientCertificates ?? new X509CertificateCollection(),
null,
Array.Empty<string>());

try
{
if (((X509Certificate2)cert).HasPrivateKey)
{
certificate = cert;
}
}
catch { }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why might throw here, and do we need to eat all exceptions or just specific ones?

Also, cert above is nullable, but we're not checking for null?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for server, we run this on msquic thread AFAIK. However, for client this is called from MsQuicConnection constructor. So we can perhaps surface any exception from the callback. (including the cast and perhaps check for PrivateKey.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW it should be OK for the callback return NULL. It is up to the server to decide if no certificate is OK or not. I think SslStram has some tests for that case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that the only way this could throw is if it is not X509Certificate2 instance or it was .Reset()ed, I will remove the try-catch and replace it with explicit checks.

}
else if (clientAuthenticationOptions.ClientCertificates != null)
{
foreach (var cert in options.ClientAuthenticationOptions.ClientCertificates)
foreach (var cert in clientAuthenticationOptions.ClientCertificates)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var => whatever type this is

{
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,52 +292,6 @@ private async ValueTask WriteAsync<TBuffer>(Action<State, TBuffer> stateSetup, T
{
ThrowIfDisposed();

using CancellationTokenRegistration registration = SetupWriteStartState(isEmpty, cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you have this PR based on this #70716. You may wanna merge it, or rebase on main here.


await WriteAsyncCore<TBuffer>(stateSetup, buffer, isEmpty, endStream).ConfigureAwait(false);

CleanupWriteCompletedState();
}

private unsafe ValueTask WriteAsyncCore<TBuffer>(Action<State, TBuffer> stateSetup, TBuffer buffer, bool isEmpty, bool endStream)
{
if (isEmpty)
{
if (endStream)
{
// Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer.
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
}
return default;
}

stateSetup(_state, buffer);

Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
int status = MsQuicApi.Api.ApiTable->StreamSend(
_state.Handle.QuicHandle,
_state.SendBuffers.Buffers,
(uint)_state.SendBuffers.Count,
endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE,
(void*)IntPtr.Zero);

if (StatusFailed(status))
{
CleanupWriteFailedState();
CleanupSendState(_state);

if (status == QUIC_STATUS_ABORTED)
{
throw ThrowHelper.GetConnectionAbortedException(_state.ConnectionState.AbortErrorCode);
}
ThrowIfFailure(status, "Could not send data to peer.");
}

return _state.SendResettableCompletionSource.GetTypelessValueTask();
}

private CancellationTokenRegistration SetupWriteStartState(bool emptyBuffer, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
lock (_state)
Expand Down Expand Up @@ -369,7 +323,7 @@ private CancellationTokenRegistration SetupWriteStartState(bool emptyBuffer, Can
}

// if token was already cancelled, this would execute synchronously
CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
using CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
bool shouldComplete = false;
Expand Down Expand Up @@ -417,14 +371,11 @@ private CancellationTokenRegistration SetupWriteStartState(bool emptyBuffer, Can

// Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed.
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending;
_state.SendState = isEmpty ? SendState.Finished : SendState.Pending;
}

return registration;
}
await WriteAsyncCore<TBuffer>(stateSetup, buffer, isEmpty, endStream).ConfigureAwait(false);

private void CleanupWriteCompletedState()
{
lock (_state)
{
if (_state.SendState == SendState.Finished)
Expand All @@ -434,19 +385,57 @@ private void CleanupWriteCompletedState()
}
}

private void CleanupWriteFailedState()
private unsafe ValueTask WriteAsyncCore<TBuffer>(Action<State, TBuffer> stateSetup, TBuffer buffer, bool isEmpty, bool endStream)
{
lock (_state)
if (isEmpty)
{
if (_state.SendState == SendState.Pending)
if (endStream)
{
_state.SendState = SendState.Finished;
// Start graceful shutdown sequence if passed in the fin flag and there is an empty buffer.
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, errorCode: 0);
}
return default;
}

stateSetup(_state, buffer);

Debug.Assert(!Monitor.IsEntered(_state), "!Monitor.IsEntered(_state)");
int status = MsQuicApi.Api.ApiTable->StreamSend(
_state.Handle.QuicHandle,
_state.SendBuffers.Buffers,
(uint)_state.SendBuffers.Count,
endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE,
(void*)IntPtr.Zero);

if (StatusFailed(status))
{
lock (_state)
{
if (_state.SendState == SendState.Pending)
{
_state.SendState = SendState.Finished;
}
}

CleanupSendState(_state);

if (status == QUIC_STATUS_ABORTED)
{
throw ThrowHelper.GetConnectionAbortedException(_state.ConnectionState.AbortErrorCode);
}
ThrowIfFailure(status, "Could not send data to peer.");
}

return _state.SendResettableCompletionSource.GetTypelessValueTask();
}

internal override async ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
{
//
// If MsQuic indicated that some data were received (QUIC_STREAM_EVENT_RECEIVE), we use it to complete the request
// synchronously. Otherwise we setup the request to be completed by the HandleEventReceive handler.
//

ThrowIfDisposed();

if (_state.ReadState == ReadState.Closed)
Expand Down Expand Up @@ -1009,6 +998,13 @@ private static unsafe int NativeCallback(QUIC_HANDLE* stream, void* context, QUI

private static unsafe int HandleEventReceive(State state, ref QUIC_STREAM_EVENT streamEvent)
{
//
// Handle MsQuic QUIC_STREAM_EVENT_RECEIVE event
//
// If there is a pending ReadAsync call, then we complete it. Otherwise we keep a pointer to the received data
// and use it to complete the next ReadAsync operation synchronously.
//

ref var receiveEvent = ref streamEvent.RECEIVE;

if (NetEventSource.Log.IsEnabled())
Expand Down
19 changes: 14 additions & 5 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public async Task CertificateCallbackThrowPropagates()
}

[Fact]
public async Task ConnectWithCertificateCallback()
public async Task ConnectWithServerCertificateCallback()
{
X509Certificate2 c1 = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
X509Certificate2 c2 = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate(); // This 'wrong' certificate but should be sufficient
Expand Down Expand Up @@ -340,10 +340,12 @@ public async Task ConnectWithCertificateForLoopbackIP_IndicatesExpectedError(str
}

[ConditionalTheory]
[InlineData(true)]
[InlineData(false)]
[InlineData(true, true)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(false, false)]
[ActiveIssue("https://github.com/dotnet/runtime/issues/64944", TestPlatforms.Windows)]
public async Task ConnectWithClientCertificate(bool sendCertificate)
public async Task ConnectWithClientCertificate(bool sendCertificate, bool useClientSelectionCallback)
{
if (PlatformDetection.IsWindows10Version20348OrLower)
{
Expand Down Expand Up @@ -371,7 +373,14 @@ public async Task ConnectWithClientCertificate(bool sendCertificate)

using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, listenerOptions);
QuicClientConnectionOptions clientOptions = CreateQuicClientOptions();
if (sendCertificate)
if (useClientSelectionCallback)
{
clientOptions.ClientAuthenticationOptions.LocalCertificateSelectionCallback = delegate
{
return sendCertificate ? ClientCertificate : null;
};
}
else if (sendCertificate)
{
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };
}
Expand Down