diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/ServerInfo.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/ServerInfo.cs index c9f69f165f..45519cbf93 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/ServerInfo.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/ServerInfo.cs @@ -73,7 +73,7 @@ internal ServerInfo( PreRoutingServerName = preRoutingServerName; UserProtocol = TdsEnums.TCP; SetDerivedNames(UserProtocol, UserServerName); - ResolvedDatabaseName = userOptions.InitialCatalog; + ResolvedDatabaseName = routing?.DatabaseName ?? userOptions.InitialCatalog; ServerSPN = serverSpn; } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs index 389dfeb822..7f254046d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Connection/SqlConnectionInternal.cs @@ -223,7 +223,7 @@ internal class SqlConnectionInternal : DbConnectionInternal, IDisposable private string _currentLanguage; - private int _currentPacketSize; + /// /// Pool this connection is associated with, if any. @@ -262,7 +262,7 @@ internal class SqlConnectionInternal : DbConnectionInternal, IDisposable /// private DbConnectionPoolAuthenticationContext _newDbConnectionPoolAuthenticationContext; - private Guid _originalClientConnectionId = Guid.Empty; + private string _originalDatabase; @@ -280,7 +280,7 @@ internal class SqlConnectionInternal : DbConnectionInternal, IDisposable // @TODO: Rename to match naming conventions (remove f prefix) private readonly bool _fResetConnection; - private string _routingDestination = null; + // @TODO: Rename to match naming conventions private bool _SQLDNSRetryEnabled = false; @@ -669,17 +669,9 @@ internal override bool IsTransactionRoot get => DelegatedTransaction?.IsActive == true; } - // @TODO: Make auto-property - internal Guid OriginalClientConnectionId - { - get => _originalClientConnectionId; - } + internal Guid OriginalClientConnectionId { get; private set; } = Guid.Empty; - // @TODO: Make auto-property - internal int PacketSize - { - get => _currentPacketSize; - } + internal int PacketSize { get; private set; } // @TODO: Make auto-property internal TdsParser Parser @@ -706,11 +698,7 @@ internal SqlConnectionPoolGroupProviderInfo PoolGroupProviderInfo /// internal byte[] PromotedDtcToken { get; private set; } - // @TODO: Make auto-property - internal string RoutingDestination - { - get => _routingDestination; - } + internal string RoutingDestination { get; private set; } internal RoutingInfo RoutingInfo { get; private set; } = null; @@ -1187,7 +1175,7 @@ internal void OnEnvChange(SqlEnvChange rec) break; case TdsEnums.ENV_PACKETSIZE: - _currentPacketSize = int.Parse(rec._newValue, CultureInfo.InvariantCulture); + PacketSize = int.Parse(rec._newValue, CultureInfo.InvariantCulture); break; case TdsEnums.ENV_COLLATION: @@ -1267,6 +1255,23 @@ internal void OnEnvChange(SqlEnvChange rec) RoutingInfo = rec._newRoutingInfo; break; + case TdsEnums.ENV_ENHANCEDROUTING: + SqlClientEventSource.Log.TryAdvancedTraceEvent( + $"SqlInternalConnectionTds.OnEnvChange | ADV | " + + $"Object ID {ObjectID}, " + + $"Received enhanced routing info"); + + if (string.IsNullOrEmpty(rec._newRoutingInfo.ServerName) || + string.IsNullOrEmpty(rec._newRoutingInfo.DatabaseName) || + rec._newRoutingInfo.Protocol != 0 || + rec._newRoutingInfo.Port == 0) + { + throw SQL.ROR_InvalidEnhancedRoutingInfo(this); + } + + RoutingInfo = rec._newRoutingInfo; + break; + default: Debug.Fail("Missed token in EnvChange!"); break; @@ -1733,11 +1738,6 @@ internal void OnFeatureExtAck(int featureId, byte[] data) } case TdsEnums.FEATUREEXT_ENHANCEDROUTINGSUPPORT: { - SqlClientEventSource.Log.TryAdvancedTraceEvent( - $"SqlInternalConnectionTds.OnFeatureExtAck | ADV | " + - $"Object ID {ObjectID}, " + - $"Received feature extension acknowledgement for ENHANCEDROUTINGSUPPORT"); - if (data.Length != 1) { SqlClientEventSource.Log.TryTraceEvent( @@ -1750,6 +1750,12 @@ internal void OnFeatureExtAck(int featureId, byte[] data) // A value of 1 indicates that the server supports the feature. IsEnhancedRoutingSupportEnabled = data[0] == 1; + + SqlClientEventSource.Log.TryAdvancedTraceEvent( + $"SqlInternalConnectionTds.OnFeatureExtAck | ADV | " + + $"Object ID {ObjectID}, " + + $"Received feature extension acknowledgement for " + + $"ENHANCEDROUTINGSUPPORT = {IsEnhancedRoutingSupportEnabled}"); break; } case TdsEnums.FEATUREEXT_USERAGENT: @@ -2967,7 +2973,7 @@ private void Login( // Gather all the settings the user set in the connection string or properties and do // the login CurrentDatabase = server.ResolvedDatabaseName; - _currentPacketSize = ConnectionOptions.PacketSize; + PacketSize = ConnectionOptions.PacketSize; _currentLanguage = ConnectionOptions.CurrentLanguage; int timeoutInSeconds = 0; @@ -3018,7 +3024,7 @@ private void Login( // Treat AD Integrated like Windows integrated when against a non-FedAuth endpoint login.useSSPI = ConnectionOptions.IntegratedSecurity || (ConnectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated && !_fedAuthRequired); - login.packetSize = _currentPacketSize; + login.packetSize = PacketSize; login.newPassword = newPassword; login.readOnlyIntent = ConnectionOptions.ApplicationIntent == ApplicationIntent.ReadOnly; login.credential = _credential; @@ -3302,11 +3308,11 @@ private void LoginNoFailover( serverInfo.ResolvedServerName, serverInfo.ServerSPN); _timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination); - _originalClientConnectionId = _clientConnectionId; - _routingDestination = serverInfo.UserServerName; + OriginalClientConnectionId = _clientConnectionId; + RoutingDestination = serverInfo.UserServerName; // Restore properties that could be changed by the environment tokens - _currentPacketSize = ConnectionOptions.PacketSize; + PacketSize = ConnectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = ConnectionOptions.InitialCatalog; ServerProvidedFailoverPartner = null; @@ -3605,11 +3611,11 @@ private void LoginWithFailover( currentServerInfo.ResolvedServerName, currentServerInfo.ServerSPN); _timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination); - _originalClientConnectionId = _clientConnectionId; - _routingDestination = currentServerInfo.UserServerName; + OriginalClientConnectionId = _clientConnectionId; + RoutingDestination = currentServerInfo.UserServerName; // Restore properties that could be changed by the environment tokens - _currentPacketSize = connectionOptions.PacketSize; + PacketSize = connectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = connectionOptions.InitialCatalog; ServerProvidedFailoverPartner = null; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index 020035841b..c2d7d26467 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -1182,6 +1182,15 @@ internal static Exception ROR_InvalidRoutingInfo(SqlConnectionInternal internalC return exc; } + internal static Exception ROR_InvalidEnhancedRoutingInfo(SqlConnectionInternal internalConnection) + { + SqlErrorCollection errors = new SqlErrorCollection(); + errors.Add(new SqlError(0, (byte)0x00, TdsEnums.FATAL_ERROR_CLASS, null, (StringsHelper.GetString(Strings.SQLROR_InvalidEnhancedRoutingInfo)), "", 0)); + SqlException exc = SqlException.CreateException(errors, null, internalConnection, innerException: null, batchCommand: null); + exc._doNotReconnect = true; + return exc; + } + internal static Exception ROR_TimeoutAfterRoutingInfo(SqlConnectionInternal internalConnection) { SqlErrorCollection errors = new SqlErrorCollection(); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs index a0c6e8d19a..cf8f4b3893 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsEnums.cs @@ -186,29 +186,7 @@ internal static class TdsEnums public const byte ENV_SPRESETCONNECTIONACK = 18; // SP_Reset_Connection ack public const byte ENV_USERINSTANCE = 19; // User Instance public const byte ENV_ROUTING = 20; // Routing (ROR) information - - public enum EnvChangeType : byte - { - ENVCHANGE_DATABASE = ENV_DATABASE, - ENVCHANGE_LANG = ENV_LANG, - ENVCHANGE_CHARSET = ENV_CHARSET, - ENVCHANGE_PACKETSIZE = ENV_PACKETSIZE, - ENVCHANGE_LOCALEID = ENV_LOCALEID, - ENVCHANGE_COMPFLAGS = ENV_COMPFLAGS, - ENVCHANGE_COLLATION = ENV_COLLATION, - ENVCHANGE_BEGINTRAN = ENV_BEGINTRAN, - ENVCHANGE_COMMITTRAN = ENV_COMMITTRAN, - ENVCHANGE_ROLLBACKTRAN = ENV_ROLLBACKTRAN, - ENVCHANGE_ENLISTDTC = ENV_ENLISTDTC, - ENVCHANGE_DEFECTDTC = ENV_DEFECTDTC, - ENVCHANGE_LOGSHIPNODE = ENV_LOGSHIPNODE, - ENVCHANGE_PROMOTETRANSACTION = ENV_PROMOTETRANSACTION, - ENVCHANGE_TRANSACTIONMANAGERADDRESS = ENV_TRANSACTIONMANAGERADDRESS, - ENVCHANGE_TRANSACTIONENDED = ENV_TRANSACTIONENDED, - ENVCHANGE_SPRESETCONNECTIONACK = ENV_SPRESETCONNECTIONACK, - ENVCHANGE_USERINSTANCE = ENV_USERINSTANCE, - ENVCHANGE_ROUTING = ENV_ROUTING - } + public const byte ENV_ENHANCEDROUTING = 21; // Enhanced Routing (ROR) information // done status stream bit masks public const int DONE_MORE = 0x0001; // more command results coming diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 143795e9ba..76e5d775f2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -3385,6 +3385,7 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb break; case TdsEnums.ENV_ROUTING: + { ushort newLength; result = stateObj.TryReadUInt16(out newLength); if (result != TdsOperationStatus.Done) @@ -3428,8 +3429,22 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb { return result; } - env._length = env._newLength + oldLength + 5; // 5=2*sizeof(UInt16)+sizeof(byte) [token+newLength+oldLength] + // Set the total length of the token + // total = headers + size of new value + size of old value + // headers = token id+newLengthHeader+oldLengthHeader = sizeof(byte) + 2*sizeof(UInt16) = 5 + env._length = env._newLength + oldLength + 5; + break; + } + + case TdsEnums.ENV_ENHANCEDROUTING: + { + result = TryProcessEnhancedRoutingToken(env, stateObj); + if (result != TdsOperationStatus.Done) + { + return result; + } break; + } default: Debug.Fail("Unknown environment change token: " + env._type); @@ -3442,6 +3457,95 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb return TdsOperationStatus.Done; } + /// + /// Processes an enhanced routing ENVCHANGE token from the TDS stream. + /// This token contains the routing information for the connection, including the protocol, + /// port, server name, and database name. The enhanced routing token has the following structure: + /// + /// NewValueLength (USHORT) - length of the new value + /// Protocol (BYTE) - routing protocol (must be 0 = TCP) + /// Port (USHORT) - routing port number + /// AlternateServerNameLength (USHORT) - length of the alternate server name in characters + /// AlternateServerName (UNICODE_STRING) - the server name to route to + /// AlternateDatabaseNameLength (USHORT) - length of the alternate database name in characters + /// AlternateDatabaseName (UNICODE_STRING) - the database name to route to + /// OldValueLength (USHORT) - length of the old value + /// OldValue (BYTE[]) - old value (skipped) + /// + /// + private TdsOperationStatus TryProcessEnhancedRoutingToken(SqlEnvChange env, TdsParserStateObject stateObj) + { + ushort newLength; + TdsOperationStatus result = stateObj.TryReadUInt16(out newLength); + if (result != TdsOperationStatus.Done) + { + return result; + } + env._newLength = newLength; + + byte protocol; + result = stateObj.TryReadByte(out protocol); + if (result != TdsOperationStatus.Done) + { + return result; + } + + ushort port; + result = stateObj.TryReadUInt16(out port); + if (result != TdsOperationStatus.Done) + { + return result; + } + + ushort serverLen; + result = stateObj.TryReadUInt16(out serverLen); + if (result != TdsOperationStatus.Done) + { + return result; + } + + string serverName; + result = stateObj.TryReadString(serverLen, out serverName); + if (result != TdsOperationStatus.Done) + { + return result; + } + + ushort databaseLen; + result = stateObj.TryReadUInt16(out databaseLen); + if (result != TdsOperationStatus.Done) + { + return result; + } + + string databaseName; + result = stateObj.TryReadString(databaseLen, out databaseName); + if (result != TdsOperationStatus.Done) + { + return result; + } + + env._newRoutingInfo = new RoutingInfo(protocol, port, serverName, databaseName); + + ushort oldLength; + result = stateObj.TryReadUInt16(out oldLength); + if (result != TdsOperationStatus.Done) + { + return result; + } + + result = stateObj.TrySkipBytes(oldLength); + if (result != TdsOperationStatus.Done) + { + return result; + } + + // 5 = 2*sizeof(UInt16)+sizeof(byte) [token+newLength+oldLength] + env._length = env._newLength + oldLength + 5; + + return TdsOperationStatus.Done; + } + private TdsOperationStatus TryReadTwoBinaryFields(SqlEnvChange env, TdsParserStateObject stateObj) { // Used by ProcessEnvChangeToken diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs index f189030d1e..7cc2c2066c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserHelperClasses.cs @@ -89,11 +89,18 @@ internal class RoutingInfo internal ushort Port { get; private set; } internal string ServerName { get; private set; } - internal RoutingInfo(byte protocol, ushort port, string servername) + /// + /// The DatabaseName property is only used when routing via an EnhancedRouting ENVCHANGE token. + /// It is not used when routing via the normal Routing ENVCHANGE token. + /// + internal string DatabaseName { get; private set; } + + internal RoutingInfo(byte protocol, ushort port, string serverName, string databaseName = null) { Protocol = protocol; Port = port; - ServerName = servername; + ServerName = serverName; + DatabaseName = databaseName; } } diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs index 10dbae761d..8d54324b8a 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.Designer.cs @@ -11794,6 +11794,15 @@ internal static string SQLROR_InvalidRoutingInfo { } } + /// + /// Looks up a localized string similar to Invalid enhanced routing information received.. + /// + internal static string SQLROR_InvalidEnhancedRoutingInfo { + get { + return ResourceManager.GetString("SQLROR_InvalidEnhancedRoutingInfo", resourceCulture); + } + } + /// /// Looks up a localized string similar to Too many redirections have occurred. Only {0} redirections per login is allowed.. /// diff --git a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx index bb81f1c357..01fc64c57e 100644 --- a/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx +++ b/src/Microsoft.Data.SqlClient/src/Resources/Strings.resx @@ -4335,6 +4335,9 @@ Invalid routing information received. + + Invalid enhanced routing information received. + Server provided routing information, but timeout already expired. diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs new file mode 100644 index 0000000000..2ec23e6ea0 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionEnhancedRoutingTests.cs @@ -0,0 +1,201 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.Connection; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; + +/// +/// Tests connection routing using the enhanced routing feature extension and envchange token +/// +[Collection("SimulatedServerTests")] +public class ConnectionEnhancedRoutingTests +{ + [Fact] + public void RoutedConnection() + { + // Arrange + using TdsServer server = new(new()); + server.Start(); + + string routingDatabaseName = Guid.NewGuid().ToString(); + bool clientProvidedCorrectDatabase = false; + server.OnLogin7Validated = loginToken => + { + clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database; + }; + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + RoutingDatabaseName = routingDatabaseName, + RequireReadOnly = false + }); + router.Start(); + router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled; + + string connectionString = (new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{router.EndPoint.Port}", + Encrypt = false, + ConnectTimeout = 10000 + }).ConnectionString; + + // Act + using SqlConnection connection = new(connectionString); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlConnectionInternal)connection.InnerConnection).RoutingDestination); + Assert.Equal(routingDatabaseName, connection.Database); + Assert.True(clientProvidedCorrectDatabase); + + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } + + [Fact] + public async Task RoutedAsyncConnection() + { + // Arrange + using TdsServer server = new(new()); + server.Start(); + + string routingDatabaseName = Guid.NewGuid().ToString(); + bool clientProvidedCorrectDatabase = false; + server.OnLogin7Validated = loginToken => + { + clientProvidedCorrectDatabase = routingDatabaseName == loginToken.Database; + }; + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + RoutingDatabaseName = routingDatabaseName, + RequireReadOnly = false + }); + router.Start(); + router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Enabled; + + string connectionString = (new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{router.EndPoint.Port}", + Encrypt = false, + ConnectTimeout = 10000 + }).ConnectionString; + + // Act + using SqlConnection connection = new(connectionString); + await connection.OpenAsync(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlConnectionInternal)connection.InnerConnection).RoutingDestination); + Assert.Equal(routingDatabaseName, connection.Database); + Assert.True(clientProvidedCorrectDatabase); + + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } + + [Fact] + public void ServerIgnoresEnhancedRoutingRequest() + { + // Arrange + using TdsServer server = new(new()); + server.Start(); + + string routingDatabaseName = Guid.NewGuid().ToString(); + bool clientProvidedCorrectDatabase = false; + server.OnLogin7Validated = loginToken => + { + clientProvidedCorrectDatabase = null == loginToken.Database; + }; + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + RequireReadOnly = false + }); + router.Start(); + router.EnhancedRoutingBehavior = FeatureExtensionBehavior.DoNotAcknowledge; + + string connectionString = (new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{router.EndPoint.Port}", + Encrypt = false, + ConnectTimeout = 10000 + }).ConnectionString; + + // Act + using SqlConnection connection = new(connectionString); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlConnectionInternal)connection.InnerConnection).RoutingDestination); + Assert.Equal("master", connection.Database); + Assert.True(clientProvidedCorrectDatabase); + + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } + + [Fact] + public void ServerRejectsEnhancedRoutingRequest() + { + // Arrange + using TdsServer server = new(new()); + server.Start(); + + string routingDatabaseName = Guid.NewGuid().ToString(); + bool clientProvidedCorrectDatabase = false; + server.OnLogin7Validated = loginToken => + { + clientProvidedCorrectDatabase = null == loginToken.Database; + }; + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + RequireReadOnly = false + }); + router.Start(); + router.EnhancedRoutingBehavior = FeatureExtensionBehavior.Disabled; + + string connectionString = (new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{router.EndPoint.Port}", + Encrypt = false, + ConnectTimeout = 10000 + }).ConnectionString; + + // Act + using SqlConnection connection = new(connectionString); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", ((SqlConnectionInternal)connection.InnerConnection).RoutingDestination); + Assert.Equal("master", connection.Database); + Assert.True(clientProvidedCorrectDatabase); + + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs index be1a2ee4e2..e6342c4fa8 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/FeatureExtensionNegotiationTests.cs @@ -14,12 +14,12 @@ namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests; [Collection("SimulatedServerTests")] -public class FeatureExtensionNegotiationTests : IClassFixture +public class FeatureExtensionNegotiationTests : IClassFixture { private TdsServer _server; private string _connectionString; - public FeatureExtensionNegotiationTests(SimulatedServerFixture fixture) + public FeatureExtensionNegotiationTests(TdsServerFixture fixture) { _server = fixture.TdsServer; SqlConnectionStringBuilder builder = new() @@ -33,13 +33,13 @@ public FeatureExtensionNegotiationTests(SimulatedServerFixture fixture) } [Theory] - [InlineData(FeatureExtensionEnablementTriState.Enabled, (byte[])[1])] - [InlineData(FeatureExtensionEnablementTriState.Disabled, (byte[])[0])] - [InlineData(FeatureExtensionEnablementTriState.DoNotAcknowledge, null)] - public void EnhancedRoutingNegotiationTest(FeatureExtensionEnablementTriState serverBehavior, byte[]? expectedAckData) + [InlineData(FeatureExtensionBehavior.Enabled, (byte[])[1])] + [InlineData(FeatureExtensionBehavior.Disabled, (byte[])[0])] + [InlineData(FeatureExtensionBehavior.DoNotAcknowledge, null)] + public void EnhancedRoutingNegotiationTest(FeatureExtensionBehavior serverBehavior, byte[]? expectedAckData) { // Arrange - _server.EnableEnhancedRouting = serverBehavior; + _server.EnhancedRoutingBehavior = serverBehavior; bool clientRequestedFeatureExtension = false; _server.OnLogin7Validated = loginToken => @@ -48,8 +48,6 @@ public void EnhancedRoutingNegotiationTest(FeatureExtensionEnablementTriState se .OfType() .FirstOrDefault(t => t.FeatureID == TDSFeatureID.EnhancedRoutingSupport); - - // Test should fail if no EnhancedRoutingSupport FE token is found Assert.NotNull(token); Assert.Equal((byte)TDSFeatureID.EnhancedRoutingSupport, (byte)token.FeatureID); @@ -78,11 +76,10 @@ public void EnhancedRoutingNegotiationTest(FeatureExtensionEnablementTriState se // Act sqlConnection.Open(); - // Assert Assert.True(clientRequestedFeatureExtension); - if (serverBehavior == FeatureExtensionEnablementTriState.DoNotAcknowledge) + if (serverBehavior == FeatureExtensionBehavior.DoNotAcknowledge) { // In DoNotAcknowledge mode, server should not acknowledge the feature extension even if client requested it Assert.False(serverAcknowledgedFeatureExtension); @@ -92,7 +89,7 @@ public void EnhancedRoutingNegotiationTest(FeatureExtensionEnablementTriState se Assert.True(serverAcknowledgedFeatureExtension); } - if (serverBehavior == FeatureExtensionEnablementTriState.Enabled) + if (serverBehavior == FeatureExtensionBehavior.Enabled) { Assert.True(((SqlConnectionInternal)sqlConnection.InnerConnection).IsEnhancedRoutingSupportEnabled); } @@ -101,21 +98,4 @@ public void EnhancedRoutingNegotiationTest(FeatureExtensionEnablementTriState se Assert.False(((SqlConnectionInternal)sqlConnection.InnerConnection).IsEnhancedRoutingSupportEnabled); } } - - - public class SimulatedServerFixture : IDisposable - { - public SimulatedServerFixture() - { - TdsServer = new TdsServer(); - TdsServer.Start(); - } - - public void Dispose() - { - TdsServer.Dispose(); - } - - public TdsServer TdsServer { get; private set; } - } } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/Fixtures/TdsServerFixture.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/Fixtures/TdsServerFixture.cs new file mode 100644 index 0000000000..2b0f5fc2e5 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/Fixtures/TdsServerFixture.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.SqlServer.TDS.Servers; + +/// +/// An xunit test fixture that manages the lifecycle of a TdsServer. +/// +public class TdsServerFixture : IDisposable +{ + public TdsServerFixture() + { + TdsServer = new TdsServer(); + TdsServer.Start(); + } + + public void Dispose() + { + TdsServer.Dispose(); + } + + public TdsServer TdsServer { get; } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs index c8728da6ca..7d80ecc53a 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/ITDSServerSession.cs @@ -102,6 +102,6 @@ public interface ITDSServerSession /// /// Indicates whether the client supports Enhanced Routing /// - bool IsEnhancedRoutingSupportEnabled { get; set; } + bool IsEnhancedRoutingSupportRequested { get; set; } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionEnablementTriState.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionBehavior.cs similarity index 86% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionEnablementTriState.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionBehavior.cs index 35103b2e7c..d318c90490 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionEnablementTriState.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FeatureExtensionBehavior.cs @@ -5,7 +5,7 @@ namespace Microsoft.SqlServer.TDS.Servers { - public enum FeatureExtensionEnablementTriState + public enum FeatureExtensionBehavior { Disabled, Enabled, diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs index 921db64b47..d2c81e6ca1 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs @@ -63,7 +63,7 @@ public delegate void OnAuthenticationCompletedDelegate( /// /// Property for setting server enhanced routing enablement state. /// - public FeatureExtensionEnablementTriState EnableEnhancedRouting { get; set; } = FeatureExtensionEnablementTriState.Disabled; + public FeatureExtensionBehavior EnhancedRoutingBehavior { get; set; } = FeatureExtensionBehavior.Disabled; /// /// Property for setting server version for vector feature extension. @@ -346,7 +346,7 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T case TDSFeatureID.EnhancedRoutingSupport: { - session.IsEnhancedRoutingSupportEnabled = true; + session.IsEnhancedRoutingSupportRequested = true; break; } @@ -817,20 +817,13 @@ protected void CheckUserAgentSupport(ITDSServerSession session, TDSMessage respo /// Response message protected void CheckEnhancedRoutingSupport(ITDSServerSession session, TDSMessage responseMessage) { - if (session.IsEnhancedRoutingSupportEnabled && - EnableEnhancedRouting != FeatureExtensionEnablementTriState.DoNotAcknowledge) + if (session.IsEnhancedRoutingSupportRequested && + EnhancedRoutingBehavior != FeatureExtensionBehavior.DoNotAcknowledge) { // Create ack data (1 byte: IsEnabled) - byte[] data = new byte[1]; - - if (EnableEnhancedRouting == FeatureExtensionEnablementTriState.Enabled) - { - data[0] = 1; - } - else - { - data[0] = 0; - } + byte[] data = EnhancedRoutingBehavior == FeatureExtensionBehavior.Enabled + ? [1] + : [0]; // Create enhanced routing support as a generic feature extension option TDSFeatureExtAckGenericOption enhancedRoutingSupportOption = new TDSFeatureExtAckGenericOption(TDSFeatureID.EnhancedRoutingSupport, (uint)data.Length, data); diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs index bc0b7e087c..b50b8fda97 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs @@ -132,7 +132,7 @@ public class GenericTdsServerSession : ITDSServerSession /// /// Indicates whether this session supports enhanced routing /// - public bool IsEnhancedRoutingSupportEnabled { get; set; } + public bool IsEnhancedRoutingSupportRequested { get; set; } #region Session Options diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs index 8e119a54cd..cdfc754e49 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs @@ -42,27 +42,20 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session // Delegate to the base class TDSMessageCollection response = base.OnPreLoginRequest(session, request); - // Check if arguments are of the routing server - if (Arguments is RoutingTdsServerArguments) + // Check if routing is configured during login + if (Arguments.RouteOnPacket == TDSMessageType.TDS7Login) { - // Cast to routing server arguments - RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; - - // Check if routing is configured during login - if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) + // Check if pre-login response is contained inside the first message + if (response.Count > 0 && response[0].Any(t => t is TDSPreLoginToken)) { - // Check if pre-login response is contained inside the first message - if (response.Count > 0 && response[0].Any(t => t is TDSPreLoginToken)) - { - // Find the first prelogin token - TDSPreLoginToken preLoginResponse = (TDSPreLoginToken)response[0].Where(t => t is TDSPreLoginToken).First(); - - // Inflate pre-login request from the message - TDSPreLoginToken preLoginRequest = request[0] as TDSPreLoginToken; - - // Update MARS with the requested value - preLoginResponse.IsMARS = preLoginRequest.IsMARS; - } + // Find the first prelogin token + TDSPreLoginToken preLoginResponse = (TDSPreLoginToken)response[0].Where(t => t is TDSPreLoginToken).First(); + + // Inflate pre-login request from the message + TDSPreLoginToken preLoginRequest = request[0] as TDSPreLoginToken; + + // Update MARS with the requested value + preLoginResponse.IsMARS = preLoginRequest.IsMARS; } } @@ -77,48 +70,41 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, // Inflate login7 request from the message TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; - // Check if arguments are of the routing server - if (Arguments is RoutingTdsServerArguments) + // Check filter + if (Arguments.RequireReadOnly && (loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly)) { - // Cast to routing server arguments - RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; + // Log request + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - // Check filter - if (ServerArguments.RequireReadOnly && (loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly)) - { - // Log request - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application intent: " + loginRequest.TypeFlags.ReadOnlyIntent.ToString(), Arguments.ServerName); - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application intent: " + loginRequest.TypeFlags.ReadOnlyIntent.ToString(), Arguments.ServerName); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + // Prepare ERROR token for the final decision + errorToken = new TDSErrorToken(18456, 1, 14, "Read-Only application intent is required for routing", Arguments.ServerName); - // Prepare ERROR token for the final decision - errorToken = new TDSErrorToken(18456, 1, 14, "Read-Only application intent is required for routing", Arguments.ServerName); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); + // Serialize the error token into the response packet + responseMessage.Add(errorToken); - // Serialize the error token into the response packet - responseMessage.Add(errorToken); + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - // Return a single message in the collection - return new TDSMessageCollection(responseMessage); - } + // Return a single message in the collection + return new TDSMessageCollection(responseMessage); } // Delegate to the base class @@ -135,44 +121,37 @@ public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session // Delegate to the base class to produce the response first TDSMessageCollection batchResponse = base.OnSQLBatchRequest(session, request); - // Check if arguments are of routing server - if (Arguments is RoutingTdsServerArguments) + // Check routing condition + if (Arguments.RouteOnPacket == TDSMessageType.SQLBatch) { - // Cast to routing server arguments - RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; + // Construct routing token + TDSPacketToken routingToken = CreateRoutingToken(); - // Check routing condition - if (ServerArguments.RouteOnPacket == TDSMessageType.SQLBatch) - { - // Construct routing token - TDSPacketToken routingToken = CreateRoutingToken(); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", routingToken); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", routingToken); - - // Insert the routing token at the beginning of the response - batchResponse[0].Insert(0, routingToken); - } - else - { - // Get the first response message - TDSMessage responseMessage = batchResponse[0]; + // Insert the routing token at the beginning of the response + batchResponse[0].Insert(0, routingToken); + } + else + { + // Get the first response message + TDSMessage responseMessage = batchResponse[0]; - // Reset the content of the first message - responseMessage.Clear(); + // Reset the content of the first message + responseMessage.Clear(); - // Prepare ERROR token with the denial details - responseMessage.Add(new TDSErrorToken(11111, 1, 14, "Client should have been routed by now", Arguments.ServerName)); + // Prepare ERROR token with the denial details + responseMessage.Add(new TDSErrorToken(11111, 1, 14, "Client should have been routed by now", Arguments.ServerName)); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", responseMessage[0]); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", responseMessage[0]); - // Prepare DONE token - responseMessage.Add(new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error)); + // Prepare DONE token + responseMessage.Add(new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error)); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", responseMessage[1]); - } + // Log response + TDSUtilities.Log(Arguments.Log, "Response", responseMessage[1]); } // Register only one message with the collection @@ -187,41 +166,34 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess // Delegate to the base class TDSMessageCollection responseMessageCollection = base.OnAuthenticationCompleted(session); - // Check if arguments are of routing server - if (Arguments is RoutingTdsServerArguments) + // Check routing condition + if (Arguments.RouteOnPacket == TDSMessageType.TDS7Login) { - // Cast to routing server arguments - RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; + // Construct routing token + TDSPacketToken routingToken = CreateRoutingToken(); - // Check routing condition - if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) - { - // Construct routing token - TDSPacketToken routingToken = CreateRoutingToken(); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", routingToken); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", routingToken); - // Get the first message - TDSMessage targetMessage = responseMessageCollection[0]; + // Get the first message + TDSMessage targetMessage = responseMessageCollection[0]; - // Index at which to insert the routing token - int insertIndex = targetMessage.Count - 1; + // Index at which to insert the routing token + int insertIndex = targetMessage.Count - 1; - // VSTS# 1021027 - Read-Only Routing yields TDS protocol error - // Resolution: Send TDS FeatureExtAct token before TDS ENVCHANGE token with routing information - TDSPacketToken featureExtAckToken = targetMessage.Find(t => t is TDSFeatureExtAckToken); + // VSTS# 1021027 - Read-Only Routing yields TDS protocol error + // Resolution: Send TDS FeatureExtAct token before TDS ENVCHANGE token with routing information + TDSPacketToken featureExtAckToken = targetMessage.Find(t => t is TDSFeatureExtAckToken); - // Check if found - if (featureExtAckToken != null) - { - // Find token position - insertIndex = targetMessage.IndexOf(featureExtAckToken); - } - - // Insert right before the done token - targetMessage.Insert(insertIndex, routingToken); + // Check if found + if (featureExtAckToken != null) + { + // Find token position + insertIndex = targetMessage.IndexOf(featureExtAckToken); } + + // Insert right before the done token + targetMessage.Insert(insertIndex, routingToken); } return responseMessageCollection; @@ -232,19 +204,32 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess /// protected TDSPacketToken CreateRoutingToken() { - // Cast to routing server arguments - RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; - - // Construct routing token value - TDSRoutingEnvChangeTokenValue routingInfo = new TDSRoutingEnvChangeTokenValue(); - - // Read the values and populate routing info - routingInfo.Protocol = (TDSRoutingEnvChangeTokenValueType)ServerArguments.RoutingProtocol; - routingInfo.ProtocolProperty = ServerArguments.RoutingTCPPort; - routingInfo.AlternateServer = ServerArguments.RoutingTCPHost; - - // Construct routing token - return new TDSEnvChangeToken(TDSEnvChangeTokenType.Routing, routingInfo); + if (string.IsNullOrEmpty(Arguments.RoutingDatabaseName)) + { + // Construct routing token value + TDSRoutingEnvChangeTokenValue routingInfo = new TDSRoutingEnvChangeTokenValue() + { + Protocol = (TDSRoutingEnvChangeTokenValueType)Arguments.RoutingProtocol, + ProtocolProperty = Arguments.RoutingTCPPort, + AlternateServer = Arguments.RoutingTCPHost, + }; + + // Construct routing token + return new TDSEnvChangeToken(TDSEnvChangeTokenType.Routing, routingInfo); + } else + { + // Construct enhanced routing token value + TdsEnhancedRoutingEnvChangeTokenValue routingInfo = new TdsEnhancedRoutingEnvChangeTokenValue() + { + Protocol = (TDSRoutingEnvChangeTokenValueType)Arguments.RoutingProtocol, + ProtocolProperty = Arguments.RoutingTCPPort, + AlternateServer = Arguments.RoutingTCPHost, + AlternateDatabase = Arguments.RoutingDatabaseName + }; + + // Construct routing token + return new TDSEnvChangeToken(TDSEnvChangeTokenType.EnhancedRouting, routingInfo); + } } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs index 95fe97f4f2..c7f5c85a69 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs @@ -24,6 +24,13 @@ public class RoutingTdsServerArguments : TdsServerArguments /// public string RoutingTCPHost { get; set; } = string.Empty; + /// + /// Setting this to a non-empty value will cause the server to include an + /// enhanced routing ENVCHANGE token in the Login Response message. + /// An empty value will include a legacy routing ENVCHANGE token. + /// + public string RoutingDatabaseName { get; set; } = string.Empty; + /// /// Packet on which routing should occur /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs index 7ceb2e0272..f129bf189a 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs @@ -28,6 +28,7 @@ public class TdsServerArguments /// /// Log to which send TDS conversation /// + /// TODO: change this to expect ITestOutputHelper? public TextWriter Log { get; set; } = null; /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeToken.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeToken.cs index d8c3655037..1c2837de31 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeToken.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeToken.cs @@ -161,6 +161,33 @@ public override bool Inflate(Stream source) throw new Exception("Non-zero old value for routing information"); } + break; + } + case TDSEnvChangeTokenType.EnhancedRouting: + { + // Read the new value length + ushort valueLength = TDSUtilities.ReadUShort(source); + + // Update token length + tokenLength -= 2; // sizeof(ushort) + + // Instantiate new value + NewValue = new TdsEnhancedRoutingEnvChangeTokenValue(); + + // Inflate new value + if (!(NewValue as TdsEnhancedRoutingEnvChangeTokenValue).Inflate(source)) + { + // We should never reach this point + throw new Exception("Routing information inflation failed"); + } + + // Read always-zero old value unsigned short + if (TDSUtilities.ReadUShort(source) != 0) + { + // We should never reach this point + throw new Exception("Non-zero old value for routing information"); + } + break; } case TDSEnvChangeTokenType.SQLCollation: @@ -295,6 +322,29 @@ public override void Deflate(Stream destination) // Write zero for the old value length TDSUtilities.WriteUShort(cache, 0); + break; + } + case TDSEnvChangeTokenType.EnhancedRouting: + { + // Create a sub-cache to determine the value length + MemoryStream subCache = new MemoryStream(); + + // Check if new value exists + if (NewValue != null) + { + // Deflate token value + (NewValue as TdsEnhancedRoutingEnvChangeTokenValue).Deflate(subCache); + } + + // Write new value length + TDSUtilities.WriteUShort(cache, (ushort)subCache.Length); + + // Write new value + subCache.WriteTo(cache); + + // Write zero for the old value length + TDSUtilities.WriteUShort(cache, 0); + break; } case TDSEnvChangeTokenType.SQLCollation: diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeTokenType.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeTokenType.cs index b4e1ae96d3..5396fe9f41 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeTokenType.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TDSEnvChangeTokenType.cs @@ -27,6 +27,7 @@ public enum TDSEnvChangeTokenType : byte TransactionEnded = 17, ResetConnectionAcknowledgement = 18, UserInstance = 19, - Routing = 20 + Routing = 20, + EnhancedRouting = 21 } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TdsEnhancedRoutingEnvChangeTokenValue.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TdsEnhancedRoutingEnvChangeTokenValue.cs new file mode 100644 index 0000000000..1c6c47efc3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/EnvChange/TdsEnhancedRoutingEnvChangeTokenValue.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; + +namespace Microsoft.SqlServer.TDS.EnvChange +{ + /// + /// Token value that represents enhanced routing information + /// + public class TdsEnhancedRoutingEnvChangeTokenValue : IInflatable, IDeflatable + { + /// + /// Protocol to use when connecting to the target server + /// + public TDSRoutingEnvChangeTokenValueType Protocol { get; set; } + + /// + /// Protocol details + /// + public object ProtocolProperty { get; set; } + + /// + /// Location of the target server + /// + public string AlternateServer { get; set; } + + /// + /// Database to connect to at the target server + /// + public string AlternateDatabase { get; set; } + + /// + /// Default constructor + /// + public TdsEnhancedRoutingEnvChangeTokenValue() + { + } + + /// + /// Initialization constructor + /// + public TdsEnhancedRoutingEnvChangeTokenValue( + TDSRoutingEnvChangeTokenValueType protocol, + object protocolProperty, + string alternateServer, + string alternateDatabase) + { + Protocol = protocol; + ProtocolProperty = protocolProperty; + AlternateServer = alternateServer; + AlternateDatabase = alternateDatabase; + } + + /// + /// Inflate the token + /// + /// Stream to inflate the token from + /// TRUE if inflation is complete + public virtual bool Inflate(Stream source) + { + // Read protocol value + Protocol = (TDSRoutingEnvChangeTokenValueType)source.ReadByte(); + + // Based on the protocol type read the rest of the token + switch (Protocol) + { + case TDSRoutingEnvChangeTokenValueType.TCP: + { + // Read port + ProtocolProperty = TDSUtilities.ReadUShort(source); + AlternateServer = TDSUtilities.ReadString(source, (ushort)(TDSUtilities.ReadUShort(source) * 2)); + AlternateDatabase = TDSUtilities.ReadString(source, (ushort)(TDSUtilities.ReadUShort(source) * 2)); + + break; + } + default: + { + throw new Exception("Unrecognized routing protocol"); + } + } + + // Inflation is complete + return true; + } + + /// + /// Deflate the token + /// + /// Stream to deflate token to + public virtual void Deflate(Stream destination) + { + // Write protocol value + destination.WriteByte((byte)Protocol); + + // Based on the protocol type read the rest of the token + switch (Protocol) + { + default: + case TDSRoutingEnvChangeTokenValueType.TCP: + { + // Write port + TDSUtilities.WriteUShort(destination, (ushort)ProtocolProperty); + + // Write alternate server name length + TDSUtilities.WriteUShort(destination, (ushort)(string.IsNullOrEmpty(AlternateServer) ? 0 : AlternateServer.Length)); + + // Write alternate server name + TDSUtilities.WriteString(destination, AlternateServer); + + // Write alternate database name length + TDSUtilities.WriteUShort(destination, (ushort)(string.IsNullOrEmpty(AlternateDatabase) ? 0 : AlternateDatabase.Length)); + TDSUtilities.WriteString(destination, AlternateDatabase); + + break; + } + } + } + + /// + /// Override string representation method + /// + public override string ToString() + { + return $"Protocol: {Protocol}; Protocol Property: {ProtocolProperty}; Alternate Server: {AlternateServer}; Alternate Database: {AlternateDatabase}"; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDS.csproj b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDS.csproj index eb83deb809..6b30cbe05c 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDS.csproj +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS/TDS.csproj @@ -4,105 +4,7 @@ Microsoft.SqlServer.TDS {8DC9D1A0-351B-47BC-A90F-B9DA542550E9} netstandard2.0 - false $(ObjFolder)$(Configuration).$(Platform)\$(AssemblyName) $(BinFolder)$(Configuration).$(Platform)\$(AssemblyName) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -