diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml index e2980f9d37..978aaa00db 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml @@ -1055,6 +1055,23 @@ GO ]]> + + Returns a name value pair collection of internal properties at the point in time the method is called. + Returns a reference of type of (string, object) items. + + + + Gets a string that contains the version of the instance of SQL Server to which the client is connected. The version of the instance of SQL Server. diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index 461bee2eaf..8b53ecfdf6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -555,6 +555,18 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public System.Guid ClientConnectionId { get { throw null; } } + + /// + /// for internal test only + /// + [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] + internal string SQLDNSCachingSupportedState { get { throw null; } } + /// + /// for internal test only + /// + [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] + internal string SQLDNSCachingSupportedStateBeforeRedirect { get { throw null; } } + object System.ICloneable.Clone() { throw null; } /// [System.ComponentModel.DefaultValueAttribute("")] @@ -639,6 +651,9 @@ public void Open(SqlConnectionOverrides overrides) { } public void ResetStatistics() { } /// public System.Collections.IDictionary RetrieveStatistics() { throw null; } + + /// + public System.Collections.Generic.IDictionary RetrieveInternalInfo() { throw null; } } /// public enum SqlConnectionOverrides diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index a4eed7751b..2566f441f0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -6,6 +6,7 @@ using Microsoft.Data.SqlClient; using System; using System.Runtime.InteropServices; +using System.Text; namespace Microsoft.Data.SqlClient { @@ -20,6 +21,8 @@ internal static partial class SNINativeMethodWrapper [UnmanagedFunctionPointer(CallingConvention.StdCall)] internal delegate void SqlAsyncCallbackDelegate(IntPtr m_ConsKey, IntPtr pPacket, uint dwError); + internal const int SniIP6AddrStringBufferLength = 48; // from SNI layer + internal static int SniMaxComposedSpnLength { get @@ -162,6 +165,20 @@ private unsafe struct SNI_CLIENT_CONSUMER_INFO public TransparentNetworkResolutionMode transparentNetworkResolution; public int totalTimeout; public bool isAzureSqlServerEndpoint; + public SNI_DNSCache_Info DNSCacheInfo; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + internal struct SNI_DNSCache_Info + { + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedFQDN; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpIPv4; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpIPv6; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpPort; } [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] @@ -236,6 +253,15 @@ internal struct SNI_Error [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + private static extern uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] private static extern uint SNIInitialize([In] IntPtr pmo); @@ -248,7 +274,8 @@ private static extern uint SNIOpenWrapper( [MarshalAs(UnmanagedType.LPWStr)] string szConnect, [In] SNIHandle pConn, out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync); + [MarshalAs(UnmanagedType.Bool)] bool fSync, + [In] ref SNI_DNSCache_Info pDNSCachedInfo); [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] private static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType); @@ -283,22 +310,53 @@ internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) { return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId); } + + internal static uint SniGetProviderNumber(SNIHandle pConn, ref ProviderEnum provNum) + { + return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PROVIDERNUM, out provNum); + } + + internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) + { + return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PEERPORT, out portNum); + } + + internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr) + { + UInt32 ret; + uint connIPLen = 0; + + int bufferSize = SniIP6AddrStringBufferLength; + StringBuilder addrBuffer = new StringBuilder(bufferSize); + + ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); + + connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen)); + + return ret; + } internal static uint SNIInitialize() { return SNIInitialize(IntPtr.Zero); } - internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync) + internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SQLDNSInfo cachedDNSInfo) { // initialize consumer info for MARS Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info(); MarshalConsumerInfo(consumerInfo, ref native_consumerInfo); - return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync); + SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); + native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + + return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ref native_cachedDNSInfo); } - internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel) + internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, SQLDNSInfo cachedDNSInfo) { fixed (byte* pin_instanceName = &instanceName[0]) { @@ -321,6 +379,11 @@ internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string cons clientConsumerInfo.totalTimeout = SniOpenTimeOut; clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring); + clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + if (spnBuffer != null) { fixed (byte* pin_spnBuffer = &spnBuffer[0]) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index c31298fb12..4328cfeebc 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -147,6 +147,9 @@ Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs + + Microsoft\Data\SqlClient\SQLFallbackDNSCache.cs + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index dd26939d29..eaeedf8f20 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -263,8 +263,10 @@ public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync) /// Asynchronous connection /// Attempt parallel connects /// + /// Used for DNS Cache + /// Used for DNS Cache /// SNI handle - public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity) + public SNIHandle CreateConnectionHandle(object callbackObject, string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { instanceName = new byte[1]; @@ -291,7 +293,7 @@ public SNIHandle CreateConnectionHandle(object callbackObject, string fullServer case DataSource.Protocol.Admin: case DataSource.Protocol.None: // default to using tcp if no protocol is provided case DataSource.Protocol.TCP: - sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel); + sniHandle = CreateTcpHandle(details, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo); break; case DataSource.Protocol.NP: sniHandle = CreateNpHandle(details, timerExpire, callbackObject, parallel); @@ -373,8 +375,10 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns /// Timer expiration /// Asynchronous I/O callback object /// Should MultiSubnetFailover be used + /// Key for DNS Cache + /// Used for DNS Cache /// SNITCPHandle - private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel) + private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, object callbackObject, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { // TCP Format: // tcp:\ @@ -412,7 +416,7 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, objec port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort; } - return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel); + return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel, cachedFQDN, ref pendingDNSInfo); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index e83e63882a..9099427e7f 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -116,12 +116,17 @@ public override int ProtocolVersion /// Connection timer expiration /// Callback object /// Parallel executions - public SNITCPHandle(string serverName, int port, long timerExpire, object callbackObject, bool parallel) + /// Key for DNS Cache + /// Used for DNS Cache + public SNITCPHandle(string serverName, int port, long timerExpire, object callbackObject, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { _callbackObject = callbackObject; _targetServer = serverName; _sendSync = new object(); + SQLDNSInfo cachedDNSInfo; + bool hasCachedDNSInfo = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); + try { TimeSpan ts = default(TimeSpan); @@ -135,33 +140,71 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; } - Task connectTask; - if (parallel) - { - Task serverAddrTask = Dns.GetHostAddressesAsync(serverName); - serverAddrTask.Wait(ts); - IPAddress[] serverAddresses = serverAddrTask.Result; + bool reportError = true; - if (serverAddresses.Length > MaxParallelIpAddresses) + // We will always first try to connect with serverName as before and let the DNS server to resolve the serverName. + // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if + // IPv4 fails. The exceptions will be throw to upper level and be handled as before. + try + { + if (parallel) { - // Fail if above 64 to match legacy behavior - ReportTcpSNIError(0, SNICommon.MultiSubnetFailoverWithMoreThan64IPs, string.Empty); - return; + _socket = TryConnectParallel(serverName, port, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); } - - connectTask = ParallelConnectAsync(serverAddresses, port); - - if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) + else { - ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); - return; + _socket = Connect(serverName, port, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo); } - - _socket = connectTask.Result; } - else + catch (Exception ex) { - _socket = Connect(serverName, port, ts, isInfiniteTimeOut); + // Retry with cached IP address + if (ex is SocketException || ex is ArgumentException || ex is AggregateException) + { + if (hasCachedDNSInfo == false) + { + throw; + } + else + { + int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port); + + try + { + if (parallel) + { + _socket = TryConnectParallel(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + } + else + { + _socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo); + } + } + catch(Exception exRetry) + { + if (exRetry is SocketException || exRetry is ArgumentNullException + || exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException) + { + if (parallel) + { + _socket = TryConnectParallel(cachedDNSInfo.AddrIPv6, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + } + else + { + _socket = Connect(cachedDNSInfo.AddrIPv6, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo); + } + } + else + { + throw; + } + } + } + } + else + { + throw; + } } if (_socket == null || !_socket.Connected) @@ -171,7 +214,11 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba _socket.Dispose(); _socket = null; } - ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); + + if (reportError) + { + ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); + } return; } @@ -196,9 +243,70 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba _status = TdsEnums.SNI_SUCCESS; } - private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout) + // Connect to server with hostName and port in parellel mode. + // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. + // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. + private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool isInfiniteTimeOut, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + { + Socket availableSocket = null; + Task connectTask; + + Task serverAddrTask = Dns.GetHostAddressesAsync(hostName); + serverAddrTask.Wait(ts); + IPAddress[] serverAddresses = serverAddrTask.Result; + + if (serverAddresses.Length > MaxParallelIpAddresses) + { + // Fail if above 64 to match legacy behavior + callerReportError = false; + ReportTcpSNIError(0, SNICommon.MultiSubnetFailoverWithMoreThan64IPs, string.Empty); + return availableSocket; + } + + string IPv4String = null; + string IPv6String = null; + + foreach (IPAddress ipAddress in serverAddresses) + { + if (ipAddress.AddressFamily == AddressFamily.InterNetwork) + { + IPv4String = ipAddress.ToString(); + } + else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) + { + IPv6String = ipAddress.ToString(); + } + } + + if (IPv4String != null || IPv6String != null) + { + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString()); + } + + connectTask = ParallelConnectAsync(serverAddresses, port); + + if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) + { + callerReportError = false; + ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); + return availableSocket; + } + + availableSocket = connectTask.Result; + return availableSocket; + + } + + // Connect to server with hostName and port. + // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. + // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. + private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); + + string IPv4String = null; + string IPv6String = null; + IPAddress serverIPv4 = null; IPAddress serverIPv6 = null; foreach (IPAddress ipAddress in ipAddresses) @@ -206,15 +314,22 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo if (ipAddress.AddressFamily == AddressFamily.InterNetwork) { serverIPv4 = ipAddress; + IPv4String = ipAddress.ToString(); } else if (ipAddress.AddressFamily == AddressFamily.InterNetworkV6) { serverIPv6 = ipAddress; + IPv6String = ipAddress.ToString(); } } ipAddresses = new IPAddress[] { serverIPv4, serverIPv6 }; Socket[] sockets = new Socket[2]; + if (IPv4String != null || IPv6String != null) + { + pendingDNSInfo = new SQLDNSInfo(cachedFQDN, IPv4String, IPv6String, port.ToString()); + } + CancellationTokenSource cts = null; void Cancel() diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 4fccb2cee8..f871d68d4b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -548,6 +548,52 @@ public override string Database } } + /// + /// To indicate the IsSupported flag sent by the server for DNS Caching. This property is for internal testing only. + /// + internal string SQLDNSCachingSupportedState + { + get + { + SqlInternalConnectionTds innerConnection = (InnerConnection as SqlInternalConnectionTds); + string result; + + if (null != innerConnection) + { + result = innerConnection.IsSQLDNSCachingSupported ? "true": "false"; + } + else + { + result = "innerConnection is null!"; + } + + return result; + } + } + + /// + /// To indicate the IsSupported flag sent by the server for DNS Caching before redirection. This property is for internal testing only. + /// + internal string SQLDNSCachingSupportedStateBeforeRedirect + { + get + { + SqlInternalConnectionTds innerConnection = (InnerConnection as SqlInternalConnectionTds); + string result; + + if (null != innerConnection) + { + result = innerConnection.IsDNSCachingBeforeRedirectSupported ? "true": "false"; + } + else + { + result = "innerConnection is null!"; + } + + return result; + } + } + /// public override string DataSource { @@ -1970,6 +2016,17 @@ private void UpdateStatistics() Statistics.UpdateStatistics(); } + /// + public IDictionary RetrieveInternalInfo() + { + IDictionary internalDictionary = new Dictionary(); + + internalDictionary.Add("SQLDNSCachingSupportedState", SQLDNSCachingSupportedState); + internalDictionary.Add("SQLDNSCachingSupportedStateBeforeRedirect", SQLDNSCachingSupportedStateBeforeRedirect); + + return internalDictionary; + } + /// object ICloneable.Clone() => new SqlConnection(this); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index c7f796fafe..4525d4c75d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -128,6 +128,61 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager; + internal bool _cleanSQLDNSCaching = false; + + private bool _serverSupportsDNSCaching = false; + + /// + /// Get or set if SQLDNSCaching is supported by the server. + /// + internal bool IsSQLDNSCachingSupported + { + get + { + return _serverSupportsDNSCaching; + } + set + { + _serverSupportsDNSCaching = value; + } + } + + private bool _SQLDNSRetryEnabled = false; + + /// + /// Get or set if we need retrying with IP received from FeatureExtAck. + /// + internal bool IsSQLDNSRetryEnabled + { + get + { + return _SQLDNSRetryEnabled; + } + set + { + _SQLDNSRetryEnabled = value; + } + } + + private bool _DNSCachingBeforeRedirect = false; + + /// + /// Get or set if the control ring send redirect token and feature ext ack with true for DNSCaching + /// + internal bool IsDNSCachingBeforeRedirectSupported + { + get + { + return _DNSCachingBeforeRedirect; + } + set + { + _DNSCachingBeforeRedirect = value; + } + } + + internal SQLDNSInfo pendingSQLDNSObject = null; + // TCE flags internal byte _tceVersionSupported; @@ -1248,6 +1303,9 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, // The GLOBALTRANSACTIONS, DATACLASSIFICATION, TCE, and UTF8 support features are implicitly requested requestedFeatures |= TdsEnums.FeatureExtension.GlobalTransactions | TdsEnums.FeatureExtension.DataClassification | TdsEnums.FeatureExtension.Tce | TdsEnums.FeatureExtension.UTF8Support; + // The SQLDNSCaching feature is implicitly set + requestedFeatures |= TdsEnums.FeatureExtension.SQLDNSCaching; + _parser.TdsLogin(login, requestedFeatures, _recoverySessionData, _fedAuthFeatureExtensionData); } @@ -2376,8 +2434,11 @@ internal void OnFeatureExtAck(int featureId, byte[] data) { if (RoutingInfo != null) { - return; + if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) { + return; + } } + switch (featureId) { case TdsEnums.FEATUREEXT_SRECOVERY: @@ -2564,6 +2625,40 @@ internal void OnFeatureExtAck(int featureId, byte[] data) _parser.DataClassificationVersion = (enabled == 0) ? TdsEnums.DATA_CLASSIFICATION_NOT_ENABLED : supportedDataClassificationVersion; break; } + + case TdsEnums.FEATUREEXT_SQLDNSCACHING: + { + SqlClientEventSource.Log.AdvancedTraceEvent(" {0}, Received feature extension acknowledgement for SQLDNSCACHING", ObjectID); + + if (data.Length < 1) + { + SqlClientEventSource.Log.TraceEvent(" {0}, Unknown token for SQLDNSCACHING", ObjectID); + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + } + + if (1 == data[0]) { + IsSQLDNSCachingSupported = true; + _cleanSQLDNSCaching = false; + + if (RoutingInfo != null) + { + IsDNSCachingBeforeRedirectSupported = true; + } + } + else { + // we receive the IsSupported whose value is 0 + IsSQLDNSCachingSupported = false; + _cleanSQLDNSCaching = true; + } + + // need to add more steps for phase 2 + // get IPv4 + IPv6 + Port number + // not put them in the DNS cache at this point but need to store them somewhere + // generate pendingSQLDNSObject and turn on IsSQLDNSRetryEnabled flag + + break; + } + default: { // Unknown feature ack @@ -2698,4 +2793,3 @@ internal void SetDerivedNames(string protocol, string serverName) } } } - diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsEnums.cs index eda89f3a42..25dbb5dd76 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -214,6 +214,7 @@ public enum EnvChangeType : byte public const byte FEATUREEXT_AZURESQLSUPPORT = 0x08; public const byte FEATUREEXT_DATACLASSIFICATION = 0x09; public const byte FEATUREEXT_UTF8SUPPORT = 0x0A; + public const byte FEATUREEXT_SQLDNSCACHING = 0x0B; [Flags] public enum FeatureExtension : uint @@ -226,6 +227,7 @@ public enum FeatureExtension : uint AzureSQLSupport = 1 << (TdsEnums.FEATUREEXT_AZURESQLSUPPORT - 1), DataClassification = 1 << (TdsEnums.FEATUREEXT_DATACLASSIFICATION - 1), UTF8Support = 1 << (TdsEnums.FEATUREEXT_UTF8SUPPORT - 1), + SQLDNSCaching = 1 << (TdsEnums.FEATUREEXT_SQLDNSCACHING - 1) } public const uint UTF8_IN_TDSCOLLATION = 0x4000000; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 507b8cc981..9070b88fa6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -163,6 +163,9 @@ internal sealed partial class TdsParser /// internal string EnclaveType { get; set; } + internal bool isTcpProtocol { get; set; } + internal string FQDNforDNSCahce { get; set; } + /// /// Get if data classification is enabled by the server. /// @@ -362,6 +365,9 @@ internal void Connect( _connHandler = connHandler; _loginWithFailover = withFailover; + // Clean up IsSQLDNSCachingSupported flag from previous status + _connHandler.IsSQLDNSCachingSupported = false; + uint sniStatus = TdsParserStateObjectFactory.Singleton.SNIStatus; if (sniStatus != TdsEnums.SNI_SUCCESS) @@ -414,9 +420,19 @@ internal void Connect( bool fParallel = _connHandler.ConnectionOptions.MultiSubnetFailover; + FQDNforDNSCahce = serverInfo.ResolvedServerName; + + int commaPos = FQDNforDNSCahce.IndexOf(","); + if (commaPos != -1) + { + FQDNforDNSCahce = FQDNforDNSCahce.Substring(0, commaPos); + } + + _connHandler.pendingSQLDNSObject = null; + // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, - out instanceName, ref _sniSpnBuffer, false, true, fParallel, integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated); + out instanceName, ref _sniSpnBuffer, false, true, fParallel, FQDNforDNSCahce, ref _connHandler.pendingSQLDNSObject, integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -455,6 +471,13 @@ internal void Connect( uint result = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); + + if (null == _connHandler.pendingSQLDNSObject) + { + // for DNS Caching phase 1 + _physicalStateObj.AssignPendingDNSInfo(serverInfo.UserProtocol, FQDNforDNSCahce, ref _connHandler.pendingSQLDNSObject); + } + SqlClientEventSource.Log.TraceEvent(" Sending prelogin handshake", "SEC"); SendPreLoginHandshake(instanceName, encrypt); @@ -473,7 +496,7 @@ internal void Connect( // On Instance failure re-connect and flush SNI named instance cache. _physicalStateObj.SniContext = SniContext.Snix_Connect; - _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref _sniSpnBuffer, true, true, fParallel, integratedSecurity); + _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref _sniSpnBuffer, true, true, fParallel, FQDNforDNSCahce, ref _connHandler.pendingSQLDNSObject, integratedSecurity); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -487,6 +510,12 @@ internal void Connect( Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); SqlClientEventSource.Log.TraceEvent(" Sending prelogin handshake", "SEC"); + if (null == _connHandler.pendingSQLDNSObject) + { + // for DNS Caching phase 1 + _physicalStateObj.AssignPendingDNSInfo(serverInfo.UserProtocol, FQDNforDNSCahce, ref _connHandler.pendingSQLDNSObject); + } + SendPreLoginHandshake(instanceName, encrypt); status = ConsumePreLoginHandshake(encrypt, trustServerCert, integratedSecurity, out marsCapable, out _connHandler._fedAuthRequired); @@ -918,8 +947,9 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(bool encrypt, bool trus SslProtocols protocol = (SslProtocols)protocolVersion; string warningMessage = protocol.GetProtocolWarning(); - if(!string.IsNullOrEmpty(warningMessage)) + if (!string.IsNullOrEmpty(warningMessage)) { + // This logs console warning of insecure protocol in use. _logger.LogWarning(_typeName, MethodBase.GetCurrentMethod().Name, warningMessage); } @@ -3099,6 +3129,20 @@ private bool TryProcessFeatureExtAck(TdsParserStateObject stateObj) } } while (featureId != TdsEnums.FEATUREEXT_TERMINATOR); + // Write to DNS Cache or clean up DNS Cache for TCP protocol + bool ret = false; + if (_connHandler._cleanSQLDNSCaching) + { + ret = SQLFallbackDNSCache.Instance.DeleteDNSInfo(FQDNforDNSCahce); + } + + if ( _connHandler.IsSQLDNSCachingSupported && _connHandler.pendingSQLDNSObject != null + && !SQLFallbackDNSCache.Instance.IsDuplicate(_connHandler.pendingSQLDNSObject)) + { + ret = SQLFallbackDNSCache.Instance.AddDNSInfo(_connHandler.pendingSQLDNSObject); + _connHandler.pendingSQLDNSObject = null; + } + // Check if column encryption was on and feature wasn't acknowledged and we aren't going to be routed to another server. if (Connection.RoutingInfo == null && _connHandler.ConnectionOptions.ColumnEncryptionSetting == SqlConnectionColumnEncryptionSetting.Enabled @@ -7818,6 +7862,20 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD return len; } + internal int WriteSQLDNSCachingFeatureRequest(bool write /* if false just calculates the length */) + { + int len = 5; // 1byte = featureID, 4bytes = featureData length + + if (write) + { + // Write Feature ID + _physicalStateObj.WriteByte(TdsEnums.FEATUREEXT_SQLDNSCACHING); + WriteInt(0, _physicalStateObj); // we don't send any data + } + + return len; + } + internal void TdsLogin(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures, SessionData recoverySessionData, FederatedAuthenticationFeatureExtensionData? fedAuthFeatureExtensionData) { _physicalStateObj.SetTimeoutSeconds(rec.timeout); @@ -7975,6 +8033,11 @@ internal void TdsLogin(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures length += WriteUTF8SupportFeatureRequest(false); } + if ((requestedFeatures & TdsEnums.FeatureExtension.SQLDNSCaching) != 0) + { + length += WriteSQLDNSCachingFeatureRequest(false); + } + length++; // for terminator } @@ -8237,6 +8300,11 @@ internal void TdsLogin(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures WriteUTF8SupportFeatureRequest(true); } + if ((requestedFeatures & TdsEnums.FeatureExtension.SQLDNSCaching) != 0) + { + WriteSQLDNSCachingFeatureRequest(true); + } + _physicalStateObj.WriteByte(0xFF); // terminator } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs index 0aa77e5cae..921d72a385 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs @@ -143,7 +143,8 @@ internal SNIHandle( out byte[] instanceName, bool flushCache, bool fSync, - bool fParallel) + bool fParallel, + SQLDNSInfo cachedDNSInfo) : base(IntPtr.Zero, true) { try @@ -158,18 +159,18 @@ internal SNIHandle( } _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, - spnBuffer, instanceName, flushCache, fSync, timeout, fParallel); + spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, cachedDNSInfo); } } // constructs SNI Handle for MARS session - internal SNIHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNIHandle parent) : base(IntPtr.Zero, true) + internal SNIHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNIHandle parent, SQLDNSInfo cachedDNSInfo) : base(IntPtr.Zero, true) { try { } finally { - _status = SNINativeMethodWrapper.SNIOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync); + _status = SNINativeMethodWrapper.SNIOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync, cachedDNSInfo); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 6888ad0453..44eb698f48 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -762,7 +762,9 @@ private void ResetCancelAndProcessAttention() } } - internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, bool isIntegratedSecurity = false); + internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false); + + internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo); internal abstract uint SniGetConnectionId(ref Guid clientConnectionId); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 6e25589986..cc2430bf24 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -49,9 +49,9 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.Singleton.PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize); - internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity) + internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) { - _sessionHandle = SNIProxy.Singleton.CreateConnectionHandle(this, serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity); + _sessionHandle = SNIProxy.Singleton.CreateConnectionHandle(this, serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo); if (_sessionHandle == null) { _parser.ProcessSNIError(this); @@ -63,6 +63,12 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni } } + // The assignment will be happened right after we resolve DNS in managed SNI layer + internal override void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo) + { + // No-op + } + internal void ReadAsyncCallback(SNIPacket packet, uint error) { ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index a38b5524df..a358992399 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -9,6 +9,7 @@ using System.Security.Authentication; using System.Threading.Tasks; using Microsoft.Data.Common; +using System.Net; namespace Microsoft.Data.SqlClient { @@ -61,7 +62,62 @@ protected override void CreateSessionHandle(TdsParserStateObject physicalConnect Debug.Assert(physicalConnection is TdsParserStateObjectNative, "Expected a stateObject of type " + this.GetType()); TdsParserStateObjectNative nativeSNIObject = physicalConnection as TdsParserStateObjectNative; SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - _sessionHandle = new SNIHandle(myInfo, nativeSNIObject.Handle); + + SQLDNSInfo cachedDNSInfo; + bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(_parser.FQDNforDNSCahce, out cachedDNSInfo); + + _sessionHandle = new SNIHandle(myInfo, nativeSNIObject.Handle, cachedDNSInfo); + } + + internal override void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo) + { + uint result; + ushort portFromSNI = 0; + string IPStringFromSNI = string.Empty; + IPAddress IPFromSNI; + _parser.isTcpProtocol = false; + SNINativeMethodWrapper.ProviderEnum providerNumber = SNINativeMethodWrapper.ProviderEnum.INVALID_PROV; + + if (string.IsNullOrEmpty(userProtocol)) + { + + result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber"); + _parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV); + } + else if (userProtocol == TdsEnums.TCP) + { + _parser.isTcpProtocol = true; + } + + // serverInfo.UserProtocol could be empty + if (_parser.isTcpProtocol) + { + result = SNINativeMethodWrapper.SniGetConnectionPort(Handle, ref portFromSNI); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort"); + + + result = SNINativeMethodWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); + + pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + + if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) + { + if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) + { + pendingDNSInfo.AddrIPv4 = IPStringFromSNI; + } + else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) + { + pendingDNSInfo.AddrIPv6 = IPStringFromSNI; + } + } + } + else + { + pendingDNSInfo = null; + } } private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) @@ -82,7 +138,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) return myInfo; } - internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, bool isIntegratedSecurity) + internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity) { // We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer spnBuffer = null; @@ -113,7 +169,10 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni } } - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel); + SQLDNSInfo cachedDNSInfo; + bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); + + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo); } protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index 197e886da2..3adffd9dfe 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -827,6 +827,9 @@ public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections. public void ResetStatistics() { } /// public System.Collections.IDictionary RetrieveStatistics() { throw null; } + + /// + public System.Collections.Generic.IDictionary RetrieveInternalInfo() { throw null; } } /// public enum SqlConnectionColumnEncryptionSetting diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 858fadfdb4..f90f36784d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -201,6 +201,9 @@ Microsoft\Data\SqlTypes\SqlTypeWorkarounds.cs + + Microsoft\Data\SqlClient\SQLFallbackDNSCache.cs + diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs index edfb5e960f..0cddc32dc1 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs @@ -4,6 +4,7 @@ using System; using System.Runtime.InteropServices; +using System.Text; using static Microsoft.Data.SqlClient.SNINativeMethodWrapper; namespace Microsoft.Data.SqlClient @@ -78,6 +79,15 @@ internal static class SNINativeManagedWrapperX64 [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, ref IntPtr pbQInfo); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + internal static extern uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIInitialize")] internal static extern uint SNIInitialize([In] IntPtr pmo); @@ -90,7 +100,8 @@ internal static extern uint SNIOpenWrapper( [MarshalAs(UnmanagedType.LPWStr)] string szConnect, [In] SNIHandle pConn, out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync); + [MarshalAs(UnmanagedType.Bool)] bool fSync, + [In] ref SNI_DNSCache_Info pDNSCachedInfo); [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] internal static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs index 89c9af997b..398ecc4872 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs @@ -4,6 +4,7 @@ using System; using System.Runtime.InteropServices; +using System.Text; using static Microsoft.Data.SqlClient.SNINativeMethodWrapper; namespace Microsoft.Data.SqlClient @@ -78,6 +79,15 @@ internal static class SNINativeManagedWrapperX86 [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, ref IntPtr pbQInfo); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Unicode)] + internal static extern uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen); + + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] + internal static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum); + [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIInitialize")] internal static extern uint SNIInitialize([In] IntPtr pmo); @@ -90,7 +100,8 @@ internal static extern uint SNIOpenWrapper( [MarshalAs(UnmanagedType.LPWStr)] string szConnect, [In] SNIHandle pConn, out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync); + [MarshalAs(UnmanagedType.Bool)] bool fSync, + [In] ref SNI_DNSCache_Info pDNSCachedInfo); [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] internal static extern IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index fefdeea4b7..66efa587b6 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -13,6 +13,7 @@ using System.Threading; using Microsoft.Data.Common; using Microsoft.Data.SqlClient; +using System.Text; namespace Microsoft.Data.SqlClient { @@ -50,6 +51,7 @@ internal static class SNINativeMethodWrapper internal const int LocalDBInvalidSqlUserInstanceDllPath = 55; internal const int LocalDBFailedToLoadDll = 56; internal const int LocalDBBadRuntime = 57; + internal const int SniIP6AddrStringBufferLength = 48; // from SNI layer internal static int SniMaxComposedSpnLength { @@ -352,6 +354,20 @@ internal unsafe struct SNI_CLIENT_CONSUMER_INFO public TransparentNetworkResolutionMode transparentNetworkResolution; public int totalTimeout; public bool isAzureSqlServerEndpoint; + public SNI_DNSCache_Info DNSCacheInfo; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + internal struct SNI_DNSCache_Info + { + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedFQDN; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpIPv4; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpIPv6; + [MarshalAs(UnmanagedType.LPWStr)] + public string wszCachedTcpPort; } [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] @@ -547,6 +563,27 @@ private static uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapp SNINativeManagedWrapperX86.SNIGetInfoWrapper(pConn, QType, ref pbQInfo); } + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ushort portNum) + { + return s_is64bitProcess ? + SNINativeManagedWrapperX64.SNIGetInfoWrapper(pConn, QType, out portNum) : + SNINativeManagedWrapperX86.SNIGetInfoWrapper(pConn, QType, out portNum); + } + + private static uint SNIGetPeerAddrStrWrapper([In] SNIHandle pConn, int bufferSize, StringBuilder addrBuffer, out uint addrLen) + { + return s_is64bitProcess ? + SNINativeManagedWrapperX64.SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen) : + SNINativeManagedWrapperX86.SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out addrLen); + } + + private static uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out ProviderEnum provNum) + { + return s_is64bitProcess ? + SNINativeManagedWrapperX64.SNIGetInfoWrapper(pConn, QType, out provNum) : + SNINativeManagedWrapperX86.SNIGetInfoWrapper(pConn, QType, out provNum); + } + private static uint SNIInitialize([In] IntPtr pmo) { return s_is64bitProcess ? @@ -566,11 +603,12 @@ private static uint SNIOpenWrapper( [MarshalAs(UnmanagedType.LPWStr)] string szConnect, [In] SNIHandle pConn, out IntPtr ppConn, - [MarshalAs(UnmanagedType.Bool)] bool fSync) + [MarshalAs(UnmanagedType.Bool)] bool fSync, + [In] ref SNI_DNSCache_Info pDNSCachedInfo) { return s_is64bitProcess ? - SNINativeManagedWrapperX64.SNIOpenWrapper(ref pConsumerInfo, szConnect, pConn, out ppConn, fSync) : - SNINativeManagedWrapperX86.SNIOpenWrapper(ref pConsumerInfo, szConnect, pConn, out ppConn, fSync); + SNINativeManagedWrapperX64.SNIOpenWrapper(ref pConsumerInfo, szConnect, pConn, out ppConn, fSync, ref pDNSCachedInfo) : + SNINativeManagedWrapperX86.SNIOpenWrapper(ref pConsumerInfo, szConnect, pConn, out ppConn, fSync, ref pDNSCachedInfo); } private static IntPtr SNIPacketAllocateWrapper([In] SafeHandle pConn, IOType IOType) @@ -687,22 +725,55 @@ internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId) { return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId); } + + internal static uint SniGetProviderNumber(SNIHandle pConn, ref ProviderEnum provNum) + { + return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PROVIDERNUM, out provNum); + } + + internal static uint SniGetConnectionPort(SNIHandle pConn, ref ushort portNum) + { + return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_PEERPORT, out portNum); + } + + internal static uint SniGetConnectionIPString(SNIHandle pConn, ref string connIPStr) + { + UInt32 ret; + uint ERROR_SUCCESS = 0; + uint connIPLen = 0; + + int bufferSize = SniIP6AddrStringBufferLength; + StringBuilder addrBuffer = new StringBuilder(bufferSize); + + ret = SNIGetPeerAddrStrWrapper(pConn, bufferSize, addrBuffer, out connIPLen); + Debug.Assert(ret == ERROR_SUCCESS, "SNIGetPeerAddrStrWrapper fail"); + + connIPStr = addrBuffer.ToString(0, Convert.ToInt32(connIPLen)); + + return ret; + } internal static uint SNIInitialize() { return SNIInitialize(IntPtr.Zero); } - internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync) + internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SQLDNSInfo cachedDNSInfo) { // initialize consumer info for MARS Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info(); MarshalConsumerInfo(consumerInfo, ref native_consumerInfo); - return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync); + SNI_DNSCache_Info native_cachedDNSInfo = new SNI_DNSCache_Info(); + native_cachedDNSInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; + native_cachedDNSInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; + native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; + native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + + return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ref native_cachedDNSInfo); } - internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, Int32 transparentNetworkResolutionStateNo, Int32 totalTimeout, Boolean isAzureSqlServerEndpoint) + internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, Int32 transparentNetworkResolutionStateNo, Int32 totalTimeout, Boolean isAzureSqlServerEndpoint, SQLDNSInfo cachedDNSInfo) { fixed (byte* pin_instanceName = &instanceName[0]) { @@ -737,6 +808,11 @@ internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string cons }; clientConsumerInfo.totalTimeout = totalTimeout; + clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6; + clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port; + if (spnBuffer != null) { fixed (byte* pin_spnBuffer = &spnBuffer[0]) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index ba55d29380..d48d907e1c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -725,6 +725,54 @@ override public string Database } } + /// + /// To indicate the IsSupported flag sent by the server for DNS Caching. This property is for internal testing only. + /// + [DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden)] + internal string SQLDNSCachingSupportedState + { + get + { + SqlInternalConnectionTds innerConnection = (InnerConnection as SqlInternalConnectionTds); + string result; + + if (null != innerConnection) + { + result = innerConnection.IsSQLDNSCachingSupported ? "true": "false"; + } + else + { + result = "innerConnection is null!"; + } + + return result; + } + } + + /// + /// To indicate the IsSupported flag sent by the server for DNS Caching before redirection. This property is for internal testing only. + /// + [DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden)] + internal string SQLDNSCachingSupportedStateBeforeRedirect + { + get + { + SqlInternalConnectionTds innerConnection = (InnerConnection as SqlInternalConnectionTds); + string result; + + if (null != innerConnection) + { + result = innerConnection.IsDNSCachingBeforeRedirectSupported ? "true": "false"; + } + else + { + result = "innerConnection is null!"; + } + + return result; + } + } + /// [ Browsable(true), @@ -2675,6 +2723,17 @@ private void UpdateStatistics() Statistics.UpdateStatistics(); } + /// + public IDictionary RetrieveInternalInfo() + { + IDictionary internalDictionary = new Dictionary(); + + internalDictionary.Add("SQLDNSCachingSupportedState", SQLDNSCachingSupportedState); + internalDictionary.Add("SQLDNSCachingSupportedStateBeforeRedirect", SQLDNSCachingSupportedStateBeforeRedirect); + + return internalDictionary; + } + // // UDT SUPPORT // diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index b3bcb0bb6e..0da06ba54d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -142,6 +142,61 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa ClientCertificateRetrievalCallback _clientCallback; SqlClientOriginalNetworkAddressInfo _originalNetworkAddressInfo; + internal bool _cleanSQLDNSCaching = false; + + private bool _serverSupportsDNSCaching = false; + + /// + /// Get or set if SQLDNSCaching FeatureExtAck is supported by the server. + /// + internal bool IsSQLDNSCachingSupported + { + get + { + return _serverSupportsDNSCaching; + } + set + { + _serverSupportsDNSCaching = value; + } + } + + private bool _SQLDNSRetryEnabled = false; + + /// + /// Get or set if we need retrying with IP received from FeatureExtAck. + /// + internal bool IsSQLDNSRetryEnabled + { + get + { + return _SQLDNSRetryEnabled; + } + set + { + _SQLDNSRetryEnabled = value; + } + } + + private bool DNSCachingBeforeRedirect = false; + + /// + /// Get or set if the control ring send redirect token and SQLDNSCaching FeatureExtAck with true + /// + internal bool IsDNSCachingBeforeRedirectSupported + { + get + { + return DNSCachingBeforeRedirect; + } + set + { + DNSCachingBeforeRedirect = value; + } + } + + internal SQLDNSInfo pendingSQLDNSObject = null; + // TCE flags internal byte _tceVersionSupported; @@ -1530,6 +1585,9 @@ private void Login(ServerInfo server, TimeoutTimer timeout, string newPassword, requestedFeatures |= TdsEnums.FeatureExtension.AzureSQLSupport; } + // The SQLDNSCaching feature is implicitly set + requestedFeatures |= TdsEnums.FeatureExtension.SQLDNSCaching; + _parser.TdsLogin(login, requestedFeatures, _recoverySessionData, _fedAuthFeatureExtensionData, _originalNetworkAddressInfo); } @@ -2815,8 +2873,11 @@ internal void OnFeatureExtAck(int featureId, byte[] data) { if (_routingInfo != null) { - return; + if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) { + return; + } } + switch (featureId) { case TdsEnums.FEATUREEXT_SRECOVERY: @@ -3030,6 +3091,40 @@ internal void OnFeatureExtAck(int featureId, byte[] data) break; } + case TdsEnums.FEATUREEXT_SQLDNSCACHING: + { + SqlClientEventSource.Log.AdvancedTraceEvent(" {0}, Received feature extension acknowledgement for SQLDNSCACHING", ObjectID); + + if (data.Length < 1) + { + SqlClientEventSource.Log.TraceEvent(" {0}, Unknown token for SQLDNSCACHING", ObjectID); + throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream); + } + + if (1 == data[0]) { + IsSQLDNSCachingSupported = true; + _cleanSQLDNSCaching = false; + + if (_routingInfo != null) + { + IsDNSCachingBeforeRedirectSupported = true; + } + } + else { + // we receive the IsSupported whose value is 0 + IsSQLDNSCachingSupported = false; + _cleanSQLDNSCaching = true; + } + + // need to add more steps for phrase 2 + // get IPv4 + IPv6 + Port number + // not put them in the DNS cache at this point but need to store them somewhere + + // generate pendingSQLDNSObject and turn on IsSQLDNSRetryEnabled flag + + break; + } + default: { // Unknown feature ack @@ -3164,4 +3259,3 @@ internal void SetDerivedNames(string protocol, string serverName) } } } - diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsEnums.cs index 92e089b298..eaba0b9cc8 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -206,6 +206,7 @@ internal static class TdsEnums public const byte FEATUREEXT_AZURESQLSUPPORT = 0x08; public const byte FEATUREEXT_DATACLASSIFICATION = 0x09; public const byte FEATUREEXT_UTF8SUPPORT = 0x0A; + public const byte FEATUREEXT_SQLDNSCACHING = 0x0B; [Flags] public enum FeatureExtension : uint @@ -217,7 +218,8 @@ public enum FeatureExtension : uint GlobalTransactions = 1 << (TdsEnums.FEATUREEXT_GLOBALTRANSACTIONS - 1), AzureSQLSupport = 1 << (TdsEnums.FEATUREEXT_AZURESQLSUPPORT - 1), DataClassification = 1 << (TdsEnums.FEATUREEXT_DATACLASSIFICATION - 1), - UTF8Support = 1 << (TdsEnums.FEATUREEXT_UTF8SUPPORT - 1), + UTF8Support = 1 << (TdsEnums.FEATUREEXT_UTF8SUPPORT - 1), + SQLDNSCaching = 1 << (TdsEnums.FEATUREEXT_SQLDNSCACHING - 1) } public const uint UTF8_IN_TDSCOLLATION = 0x4000000; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 859c8bd3bb..4f5fd8a2e1 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -17,6 +17,7 @@ using System.Threading; using System.Threading.Tasks; using System.Xml; +using System.Net; using Microsoft.Data.Common; using Microsoft.Data.Sql; using Microsoft.Data.SqlClient.DataClassification; @@ -283,6 +284,10 @@ internal bool IsColumnEncryptionSupported /// internal string EnclaveType { get; set; } + internal bool isTcpProtocol { get; set; } + + internal string FQDNforDNSCahce { get; set; } + /// /// Get if data classification is enabled by the server. /// @@ -499,6 +504,9 @@ internal void Connect(ServerInfo serverInfo, _connHandler = connHandler; _loginWithFailover = withFailover; + // Clean up IsSQLDNSCachingSupported flag from previous status + _connHandler.IsSQLDNSCachingSupported = false; + UInt32 sniStatus = SNILoadHandle.SingletonInstance.SNIStatus; if (sniStatus != TdsEnums.SNI_SUCCESS) { @@ -567,8 +575,16 @@ internal void Connect(ServerInfo serverInfo, int totalTimeout = _connHandler.ConnectionOptions.ConnectTimeout; + FQDNforDNSCahce = serverInfo.ResolvedServerName; + + int commaPos = FQDNforDNSCahce.IndexOf(","); + if (commaPos != -1) + { + FQDNforDNSCahce = FQDNforDNSCahce.Substring(0, commaPos); + } + _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, - out instanceName, _sniSpnBuffer, false, true, fParallel, transparentNetworkResolutionState, totalTimeout); + out instanceName, _sniSpnBuffer, false, true, fParallel, transparentNetworkResolutionState, totalTimeout, FQDNforDNSCahce); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -607,6 +623,9 @@ internal void Connect(ServerInfo serverInfo, UInt32 result = SNINativeMethodWrapper.SniGetConnectionId(_physicalStateObj.Handle, ref _connHandler._clientConnectionId); Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); + + // for DNS Caching phase 1 + AssignPendingDNSInfo(serverInfo.UserProtocol, FQDNforDNSCahce); // UNDONE - send "" for instance now, need to fix later SqlClientEventSource.Log.TraceEvent(" Sending prelogin handshake", "SEC"); @@ -628,7 +647,7 @@ internal void Connect(ServerInfo serverInfo, // On Instance failure re-connect and flush SNI named instance cache. _physicalStateObj.SniContext = SniContext.Snix_Connect; - _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, _sniSpnBuffer, true, true, fParallel, transparentNetworkResolutionState, totalTimeout); + _physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, _sniSpnBuffer, true, true, fParallel, transparentNetworkResolutionState, totalTimeout, serverInfo.ResolvedServerName); if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { @@ -642,6 +661,9 @@ internal void Connect(ServerInfo serverInfo, Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId"); SqlClientEventSource.Log.TraceEvent(" Sending prelogin handshake", "SEC"); + // for DNS Caching phase 1 + AssignPendingDNSInfo(serverInfo.UserProtocol, FQDNforDNSCahce); + SendPreLoginHandshake(instanceName, encrypt, !string.IsNullOrEmpty(certificate), useOriginalAddressInfo); status = ConsumePreLoginHandshake(authType, encrypt, trustServerCert, integratedSecurity, serverCallback, clientCallback, out marsCapable, out _connHandler._fedAuthRequired); @@ -669,6 +691,60 @@ internal void Connect(ServerInfo serverInfo, return; } + // Retrieve the IP and port number from native SNI for TCP protocol. The IP information is stored temporarily in the + // pendingSQLDNSObject but not in the DNS Cache at this point. We only add items to the DNS Cache after we receive the + // IsSupported flag as true in the feature ext ack from server. + internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey) + { + UInt32 result; + ushort portFromSNI = 0; + string IPStringFromSNI = string.Empty; + IPAddress IPFromSNI; + isTcpProtocol = false; + SNINativeMethodWrapper.ProviderEnum providerNumber = SNINativeMethodWrapper.ProviderEnum.INVALID_PROV; + + if (string.IsNullOrEmpty(userProtocol)) + { + + result = SNINativeMethodWrapper.SniGetProviderNumber(_physicalStateObj.Handle, ref providerNumber); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber"); + isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV); + } + else if (userProtocol == TdsEnums.TCP) + { + isTcpProtocol = true; + } + + // serverInfo.UserProtocol could be empty + if (isTcpProtocol) + { + result = SNINativeMethodWrapper.SniGetConnectionPort(_physicalStateObj.Handle, ref portFromSNI); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort"); + + + result = SNINativeMethodWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI); + Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString"); + + _connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString()); + + if (IPAddress.TryParse(IPStringFromSNI, out IPFromSNI)) + { + if (System.Net.Sockets.AddressFamily.InterNetwork == IPFromSNI.AddressFamily) + { + _connHandler.pendingSQLDNSObject.AddrIPv4 = IPStringFromSNI; + } + else if (System.Net.Sockets.AddressFamily.InterNetworkV6 == IPFromSNI.AddressFamily) + { + _connHandler.pendingSQLDNSObject.AddrIPv6 = IPStringFromSNI; + } + } + } + else + { + _connHandler.pendingSQLDNSObject = null; + } + } + internal void RemoveEncryption() { Debug.Assert((_encryptionOption & EncryptionOptions.OPTIONS_MASK) == EncryptionOptions.LOGIN, "Invalid encryption option state"); @@ -1220,6 +1296,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake(SqlAuthenticationMethod string warningMessage = SslProtocolsHelper.GetProtocolWarning(protocolVersion); if (!string.IsNullOrEmpty(warningMessage)) { + // This logs console warning of insecure protocol in use. _logger.LogWarning(_typeName, MethodBase.GetCurrentMethod().Name, warningMessage); } @@ -3474,6 +3551,20 @@ private bool TryProcessFeatureExtAck(TdsParserStateObject stateObj) } } while (featureId != TdsEnums.FEATUREEXT_TERMINATOR); + // Write to DNS Cache or clean up DNS Cache for TCP protocol + bool ret = false; + if (_connHandler._cleanSQLDNSCaching) + { + ret = SQLFallbackDNSCache.Instance.DeleteDNSInfo(FQDNforDNSCahce); + } + + if ( _connHandler.IsSQLDNSCachingSupported && _connHandler.pendingSQLDNSObject != null + && !SQLFallbackDNSCache.Instance.IsDuplicate(_connHandler.pendingSQLDNSObject)) + { + ret = SQLFallbackDNSCache.Instance.AddDNSInfo(_connHandler.pendingSQLDNSObject); + _connHandler.pendingSQLDNSObject = null; + } + // Check if column encryption was on and feature wasn't acknowledged and we aren't going to be routed to another server. if (this.Connection.RoutingInfo == null && _connHandler.ConnectionOptions.ColumnEncryptionSetting == SqlConnectionColumnEncryptionSetting.Enabled @@ -8552,6 +8643,20 @@ internal int WriteFedAuthFeatureRequest(FederatedAuthenticationFeatureExtensionD return len; } + internal int WriteSQLDNSCachingFeatureRequest(bool write /* if false just calculates the length */) + { + int len = 5; // 1byte = featureID, 4bytes = featureData length + + if (write) + { + // Write Feature ID + _physicalStateObj.WriteByte(TdsEnums.FEATUREEXT_SQLDNSCACHING); + WriteInt(0, _physicalStateObj); // we don't send any data + } + + return len; + } + internal void TdsLogin(SqlLogin rec, TdsEnums.FeatureExtension requestedFeatures, SessionData recoverySessionData, @@ -8739,6 +8844,12 @@ internal void TdsLogin(SqlLogin rec, { length += WriteUTF8SupportFeatureRequest(false); } + + if ((requestedFeatures & TdsEnums.FeatureExtension.SQLDNSCaching) != 0) + { + length += WriteSQLDNSCachingFeatureRequest(false); + } + length++; // for terminator } } @@ -9010,6 +9121,12 @@ internal void TdsLogin(SqlLogin rec, { WriteUTF8SupportFeatureRequest(true); } + + if ((requestedFeatures & TdsEnums.FeatureExtension.SQLDNSCaching) != 0) + { + WriteSQLDNSCachingFeatureRequest(true); + } + _physicalStateObj.WriteByte(0xFF); // terminator } } // try diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs index e351d36d61..30e874995c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserSafeHandles.cs @@ -149,7 +149,8 @@ internal SNIHandle( bool fSync, bool fParallel, TransparentNetworkResolutionState transparentNetworkResolutionState, - int totalTimeout) + int totalTimeout, + SQLDNSInfo cachedDNSInfo) : base(IntPtr.Zero, true) { @@ -171,19 +172,19 @@ internal SNIHandle( int transparentNetworkResolutionStateNo = (int)transparentNetworkResolutionState; _status = SNINativeMethodWrapper.SNIOpenSyncEx(myInfo, serverName, ref base.handle, spnBuffer, instanceName, flushCache, fSync, timeout, fParallel, transparentNetworkResolutionStateNo, totalTimeout, - ADP.IsAzureSqlServerEndpoint(serverName)); + ADP.IsAzureSqlServerEndpoint(serverName), cachedDNSInfo); } } // constructs SNI Handle for MARS session - internal SNIHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNIHandle parent) : base(IntPtr.Zero, true) + internal SNIHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNIHandle parent, SQLDNSInfo cachedDNSInfo) : base(IntPtr.Zero, true) { RuntimeHelpers.PrepareConstrainedRegions(); try { } finally { - _status = SNINativeMethodWrapper.SNIOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync); + _status = SNINativeMethodWrapper.SNIOpenMarsSession(myInfo, parent, ref base.handle, parent._fSync, cachedDNSInfo); } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 63f0f7d68b..09305c3cf4 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -294,7 +294,11 @@ internal TdsParserStateObject(TdsParser parser, SNIHandle physicalConnection, bo SetPacketSize(_parser._physicalStateObj._outBuff.Length); SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - _sessionHandle = new SNIHandle(myInfo, physicalConnection); + + SQLDNSInfo cachedDNSInfo; + bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(_parser.FQDNforDNSCahce, out cachedDNSInfo); + + _sessionHandle = new SNIHandle(myInfo, physicalConnection, cachedDNSInfo); if (_sessionHandle.Status != TdsEnums.SNI_SUCCESS) { AddError(parser.ProcessSNIError(this)); @@ -820,7 +824,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) return myInfo; } - internal void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, byte[] spnBuffer, bool flushCache, bool async, bool fParallel, TransparentNetworkResolutionState transparentNetworkResolutionState, int totalTimeout) + internal void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, byte[] spnBuffer, bool flushCache, bool async, bool fParallel, TransparentNetworkResolutionState transparentNetworkResolutionState, int totalTimeout, string cachedFQDN) { SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); @@ -842,7 +846,13 @@ internal void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeo timeout = 0; } } - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout); + + // serverName : serverInfo.ExtendedServerName + // may not use this serverName as key + SQLDNSInfo cachedDNSInfo; + bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); + + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, cachedDNSInfo); } internal bool Deactivate() diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs new file mode 100644 index 0000000000..e18b61cee4 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SQLFallbackDNSCache.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; + +namespace Microsoft.Data.SqlClient +{ + internal class SQLFallbackDNSCache + { + private static readonly SQLFallbackDNSCache _SQLFallbackDNSCache = new SQLFallbackDNSCache(); + private static readonly int initialCapacity = 101; // give some prime number here according to MSDN docs. It will be resized if reached capacity. + private ConcurrentDictionary DNSInfoCache; + + // singleton instance + public static SQLFallbackDNSCache Instance { get { return _SQLFallbackDNSCache; } } + + private SQLFallbackDNSCache() + { + int level = 4 * Environment.ProcessorCount; + DNSInfoCache = new ConcurrentDictionary(concurrencyLevel: level, + capacity: initialCapacity, + comparer: StringComparer.OrdinalIgnoreCase); + } + + internal bool AddDNSInfo(SQLDNSInfo item) + { + if (null != item) + { + if (DNSInfoCache.ContainsKey(item.FQDN)) + { + + DeleteDNSInfo(item.FQDN); + } + + return DNSInfoCache.TryAdd(item.FQDN, item); + } + + return false; + } + + internal bool DeleteDNSInfo(string FQDN) + { + SQLDNSInfo value; + return DNSInfoCache.TryRemove(FQDN, out value); + } + + internal bool GetDNSInfo(string FQDN, out SQLDNSInfo result) + { + return DNSInfoCache.TryGetValue(FQDN, out result); + } + + internal bool IsDuplicate(SQLDNSInfo newItem) + { + if (null != newItem) + { + SQLDNSInfo oldItem; + if (GetDNSInfo(newItem.FQDN, out oldItem)) + { + return (newItem.AddrIPv4 == oldItem.AddrIPv4 && + newItem.AddrIPv6 == oldItem.AddrIPv6 && + newItem.Port == oldItem.Port); + } + } + + return false; + } + } + + internal class SQLDNSInfo + { + public string FQDN { get; set; } + public string AddrIPv4 { get; set; } + public string AddrIPv6 { get; set; } + public string Port { get; set; } + + internal SQLDNSInfo(string FQDN, string ipv4, string ipv6, string port) + { + this.FQDN = FQDN; + AddrIPv4 = ipv4; + AddrIPv6 = ipv6; + Port = port; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientLogger.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientLogger.cs index 68afb07e49..5150862224 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientLogger.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientLogger.cs @@ -32,7 +32,6 @@ public void LogWarning(string type, string method, string message) /// public void LogError(string type, string method, string message) { - Console.Out.WriteLine(message); SqlClientEventSource.Log.TraceEvent("{3}", type, method, LogLevel.Error, message); } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs index 244e008f96..f2870aae1e 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs @@ -4,12 +4,19 @@ using System; using System.Data; +using System.Collections.Generic; using Xunit; namespace Microsoft.Data.SqlClient.Tests { public partial class SqlConnectionTest { + private static readonly string[] s_retrieveInternalInfoKeys = + { + "SQLDNSCachingSupportedState", + "SQLDNSCachingSupportedStateBeforeRedirect" + }; + [Fact] public void Constructor1() { @@ -1212,5 +1219,47 @@ public void ServerVersion_Connection_Closed() Assert.NotNull(ex.Message); } } + + [Fact] + public void RetrieveInternalInfo_Success() + { + SqlConnection cn = new SqlConnection(); + IDictionary d = cn.RetrieveInternalInfo(); + + Assert.NotNull(d); + } + + [Fact] + public void RetrieveInternalInfo_ExpectedKeysInDictionary_Success() + { + SqlConnection cn = new SqlConnection(); + IDictionary d = cn.RetrieveInternalInfo(); + + Assert.NotEmpty(d); + Assert.Equal(s_retrieveInternalInfoKeys.Length, d.Count); + + Assert.NotEmpty(d.Keys); + Assert.Equal(s_retrieveInternalInfoKeys.Length, d.Keys.Count); + + Assert.NotEmpty(d.Values); + Assert.Equal(s_retrieveInternalInfoKeys.Length, d.Values.Count); + + foreach(string key in s_retrieveInternalInfoKeys) + { + Assert.True(d.ContainsKey(key)); + + d.TryGetValue(key, out object value); + Assert.NotNull(value); + Assert.IsType(value); + } + } + + [Fact] + public void RetrieveInternalInfo_UnexpectedKeysInDictionary_Success() + { + SqlConnection cn = new SqlConnection(); + IDictionary d = cn.RetrieveInternalInfo(); + Assert.False(d.ContainsKey("Foo")); + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 6d0af535e4..1e8fa398c6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -43,6 +43,12 @@ public static class DataTestUtility public static readonly bool SupportsFileStream = false; public static readonly bool UseManagedSNIOnWindows = false; + public static readonly string DNSCachingConnString = null; + public static readonly string DNSCachingServerCR = null; // this is for the control ring + public static readonly string DNSCachingServerTR = null; // this is for the tenant ring + public static readonly bool IsDNSCachingSupportedCR = false; // this is for the control ring + public static readonly bool IsDNSCachingSupportedTR = false; // this is for the tenant ring + public const string UdtTestDbName = "UdtTestDb"; public const string AKVKeyName = "TestSqlClientAzureKeyVaultProvider"; private const string ManagedNetworkingAppContextSwitch = "Switch.Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows"; @@ -75,6 +81,11 @@ private class Config public bool SupportsLocalDb = false; public bool SupportsFileStream = false; public bool UseManagedSNIOnWindows = false; + public string DNSCachingConnString = null; + public string DNSCachingServerCR = null; // this is for the control ring + public string DNSCachingServerTR = null; // this is for the tenant ring + public bool IsDNSCachingSupportedCR = false; // this is for the control ring + public bool IsDNSCachingSupportedTR = false; // this is for the tenant ring } static DataTestUtility() @@ -100,6 +111,12 @@ static DataTestUtility() TracingEnabled = c.TracingEnabled; UseManagedSNIOnWindows = c.UseManagedSNIOnWindows; + DNSCachingConnString = c.DNSCachingConnString; + DNSCachingServerCR = c.DNSCachingServerCR; + DNSCachingServerTR = c.DNSCachingServerTR; + IsDNSCachingSupportedCR = c.IsDNSCachingSupportedCR; + IsDNSCachingSupportedTR = c.IsDNSCachingSupportedTR; + if (TracingEnabled) { TraceListener = new DataTestUtility.TraceEventListener(); @@ -264,6 +281,9 @@ public static bool IsSupportedDataClassification() } return true; } + + public static bool IsDNSCachingSetup() => !string.IsNullOrEmpty(DNSCachingConnString); + public static bool IsUdtTestDatabasePresent() => IsDatabasePresent(UdtTestDbName); public static bool AreConnStringsSetup() diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index d0152fca05..34bd1e6ead 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -195,6 +195,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DNSCachingTest/DNSCachingTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DNSCachingTest/DNSCachingTest.cs new file mode 100644 index 0000000000..33460acb8d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DNSCachingTest/DNSCachingTest.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + + public class DNSCachingTest + { + public static Assembly systemData = Assembly.GetAssembly(typeof(SqlConnection)); + public static Type SQLFallbackDNSCacheType = systemData.GetType("Microsoft.Data.SqlClient.SQLFallbackDNSCache"); + public static Type SQLDNSInfoType = systemData.GetType("Microsoft.Data.SqlClient.SQLDNSInfo"); + public static MethodInfo SQLFallbackDNSCacheGetDNSInfo = SQLFallbackDNSCacheType.GetMethod("GetDNSInfo", BindingFlags.Instance | BindingFlags.NonPublic); + + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsDNSCachingSetup))] + public void DNSCachingIsSupportedFlag() + { + string expectedDNSCachingSupportedCR = DataTestUtility.IsDNSCachingSupportedCR ? "true" : "false"; + string expectedDNSCachingSupportedTR = DataTestUtility.IsDNSCachingSupportedTR ? "true" : "false"; + + using(SqlConnection connection = new SqlConnection(DataTestUtility.DNSCachingConnString)) + { + connection.Open(); + + IDictionary dictionary = connection.RetrieveInternalInfo(); + bool ret = dictionary.TryGetValue("SQLDNSCachingSupportedState", out object val); + ret = dictionary.TryGetValue("SQLDNSCachingSupportedStateBeforeRedirect", out object valBeforeRedirect); + string isSupportedStateTR = (string)val; + string isSupportedStateCR = (string)valBeforeRedirect; + Assert.Equal(expectedDNSCachingSupportedCR, isSupportedStateCR); + Assert.Equal(expectedDNSCachingSupportedTR, isSupportedStateTR); + } + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsDNSCachingSetup))] + public void DNSCachingGetDNSInfo() + { + using(SqlConnection connection = new SqlConnection(DataTestUtility.DNSCachingConnString)) + { + connection.Open(); + } + + var SQLFallbackDNSCacheInstance = SQLFallbackDNSCacheType.GetProperty("Instance", BindingFlags.Static | BindingFlags.Public).GetValue(null); + + var serverList = new List>(); + serverList.Add(new KeyValuePair(DataTestUtility.DNSCachingServerCR, DataTestUtility.IsDNSCachingSupportedCR)); + serverList.Add(new KeyValuePair(DataTestUtility.DNSCachingServerTR, DataTestUtility.IsDNSCachingSupportedTR)); + + foreach(var server in serverList) + { + object[] parameters; + bool ret; + + if (!string.IsNullOrEmpty(server.Key)) + { + parameters = new object[] { server.Key, null }; + ret = (bool)SQLFallbackDNSCacheGetDNSInfo.Invoke(SQLFallbackDNSCacheInstance, parameters); + + if (server.Value) + { + Assert.NotNull(parameters[1]); + Assert.Equal(server.Key, (string)SQLDNSInfoType.GetProperty("FQDN").GetValue(parameters[1])); + } + else + { + Assert.Null(parameters[1]); + } + } + } + } + } +} diff --git a/tools/props/Versions.props b/tools/props/Versions.props index e10ec44a73..817a1dbc1a 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -9,7 +9,7 @@ - 2.0.0-preview1.20141.10 + 2.0.0 4.3.1 4.3.0 @@ -24,7 +24,7 @@ 4.7.0 - 2.0.0-preview1.20141.10 + 2.0.0 4.7.0 4.7.0 4.7.0 diff --git a/tools/specs/Microsoft.Data.SqlClient.nuspec b/tools/specs/Microsoft.Data.SqlClient.nuspec index f6be6bab6b..3e09d83333 100644 --- a/tools/specs/Microsoft.Data.SqlClient.nuspec +++ b/tools/specs/Microsoft.Data.SqlClient.nuspec @@ -27,13 +27,13 @@ When using NuGet 3.x this package requires at least version 3.4. sqlclient microsoft.data.sqlclient - + - + @@ -45,7 +45,7 @@ When using NuGet 3.x this package requires at least version 3.4. - + @@ -57,7 +57,7 @@ When using NuGet 3.x this package requires at least version 3.4. - +