Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions CosmosDBShell.Tests/CommandTests/ConnectCommandTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace CosmosShell.Tests.CommandTests;
using Azure.Data.Cosmos.Shell.Lsp.Semantics;
using Azure.Data.Cosmos.Shell.Parser;
using Microsoft.Azure.Cosmos;
using System.Net.Http;

public class ConnectCommandTests
{
Expand Down Expand Up @@ -68,6 +69,74 @@ public void ConnectCommand_VSCodeCredentialOption_DoesNotProduceUnknownOptionDia
Assert.DoesNotContain(model.Diagnostics, diagnostic => diagnostic.Code == "SEM002");
}

[Fact]
public async Task ConnectCommand_EmulatorOption_BindsFlag()
{
var command = await BindConnectCommandAsync("connect --emulator");

Assert.Null(command.ConnectionString);
Assert.True(command.Emulator);
}

[Fact]
public async Task ConnectCommand_EmulatorShortOption_BindsFlag()
{
var command = await BindConnectCommandAsync("connect -e");

Assert.True(command.Emulator);
}

[Fact]
public async Task ConnectCommand_EmulatorWithExplicitEndpoint_BindsBoth()
{
var command = await BindConnectCommandAsync("connect --emulator https://localhost:9000/");

Assert.Equal("https://localhost:9000/", command.ConnectionString);
Assert.True(command.Emulator);
}

[Fact]
public async Task ConnectAsync_EmulatorAgainstNonLocalEndpoint_Throws()
{
using var shell = ShellInterpreter.CreateInstance();

var ex = await Assert.ThrowsAsync<ShellException>(() => shell.ConnectAsync(
"https://contoso.documents.azure.com:443/",
forceEmulator: true,
token: CancellationToken.None));
Assert.Contains("emulator", ex.Message, StringComparison.OrdinalIgnoreCase);
}

[Fact]
public void IsTlsHandshakeFailure_DetectsAuthenticationException()
{
var ex = new HttpRequestException("send failed", new System.Security.Authentication.AuthenticationException("inner"));
Assert.True(ShellInterpreter.IsTlsHandshakeFailure(ex));
}

[Fact]
public void IsTlsHandshakeFailure_DetectsResetSocket()
{
var ex = new HttpRequestException("send failed", new System.Net.Sockets.SocketException((int)System.Net.Sockets.SocketError.ConnectionReset));
Assert.True(ShellInterpreter.IsTlsHandshakeFailure(ex));
}

[Fact]
public void IsTlsHandshakeFailure_IgnoresGenericHttpRequestException()
{
// No type-based marker => not a TLS handshake failure (was previously matched by the
// brittle "SSL" message substring check).
var ex = new HttpRequestException("Connection refused (SSL inside the message text only)");
Assert.False(ShellInterpreter.IsTlsHandshakeFailure(ex));
}

[Fact]
public void IsTlsHandshakeFailure_IgnoresUnrelatedSocketErrors()
{
var ex = new HttpRequestException("send failed", new System.Net.Sockets.SocketException((int)System.Net.Sockets.SocketError.HostUnreachable));
Assert.False(ShellInterpreter.IsTlsHandshakeFailure(ex));
}

private static async Task<ConnectCommand> BindConnectCommandAsync(string commandText)
{
var parser = new StatementParser(commandText);
Expand Down
22 changes: 22 additions & 0 deletions CosmosDBShell.Tests/UtilTest/ParseDocDBConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,28 @@ public void TestIsLocalEmulatorEndpoint_ConnectionStringWithoutKey()
Assert.True(ParsedDocDBConnectionString.IsLocalEmulatorEndpoint("AccountEndpoint=https://localhost:8081/;"));
}

[Theory]
[InlineData("https://notlocalhost.com/")]
[InlineData("https://localhost.contoso.com/")]
[InlineData("https://contoso.localhost.com/")]
[InlineData("https://10.127.0.0.1.example.com/")]
[InlineData("AccountEndpoint=https://notlocalhost.com/;")]
[InlineData("AccountEndpoint=https://contoso.documents.azure.com:443/;AccountKey=k;")]
public void TestIsLocalEmulatorEndpoint_NonLocalHostsAreRejected(string input)
{
Assert.False(ParsedDocDBConnectionString.IsLocalEmulatorEndpoint(input));
}

[Theory]
[InlineData("https://localhost:8081/")]
[InlineData("https://LOCALHOST:8081/")]
[InlineData("https://127.0.0.1:8081/")]
[InlineData("http://[::1]:8081/")]
public void TestIsLocalEmulatorEndpoint_LoopbackHostsAreAccepted(string input)
{
Assert.True(ParsedDocDBConnectionString.IsLocalEmulatorEndpoint(input));
}

[Fact]
public void TestPlainLocalhostUrlNotParsedAsConnectionString()
{
Expand Down
14 changes: 10 additions & 4 deletions CosmosDBShell/Azure.Data.Cosmos.Shell.Commands/ConnectCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace Azure.Data.Cosmos.Shell.Commands;
[CosmosCommand("connect")]
[CosmosExample("connect", Description = "Show current connection information and mode")]
[CosmosExample("connect \"AccountEndpoint=https://myaccount.documents.azure.com:443/;AccountKey=mykey;\"", Description = "Connect using connection string with account key")]
[CosmosExample("connect --emulator", Description = "Connect to the local Cosmos DB Emulator on https://localhost:8081 (falls back to HTTP if the TLS handshake fails)")]
[CosmosExample("connect https://localhost:8081", Description = "Connect to the local Cosmos DB Emulator (uses well-known key and gateway mode)")]
[CosmosExample("connect https://myaccount.documents.azure.com:443/ -hint=user@contoso.com", Description = "Connect using Entra ID authentication with login hint")]
[CosmosExample("connect https://myaccount.documents.azure.com:443/ -tenant=<tenant-id> -mode=gateway", Description = "Connect using Entra ID with gateway connection mode")]
Expand Down Expand Up @@ -46,10 +47,13 @@ internal partial class ConnectCommand : CosmosCommand
[CosmosOption("vscode-credential", "connect-vscode-credential", Hidden = true)]
public bool UseVSCodeCredential { get; init; }

[CosmosOption("emulator", "e")]
public bool Emulator { get; init; }

public async override Task<CommandState> ExecuteAsync(ShellInterpreter shell, CommandState commandState, string commandText, CancellationToken token)
{
// If no connection string provided, show current connection info
if (this.ConnectionString is null)
// If no connection string and not using --emulator, show current connection info
if (this.ConnectionString is null && !this.Emulator)
{
return await PrintConnectionInfoAsync(shell, commandState, token);
}
Expand Down Expand Up @@ -85,12 +89,14 @@ public async override Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co

try
{
await shell.ConnectAsync(this.ConnectionString, this.LoginHint, connectionMode, tenantId: this.TenantId, authorityHost: this.AuthorityHost, managedIdentityClientId: this.ManagedIdentityClientId, useVSCodeCredential: this.UseVSCodeCredential, token: token);
await shell.ConnectAsync(this.ConnectionString, this.LoginHint, connectionMode, tenantId: this.TenantId, authorityHost: this.AuthorityHost, managedIdentityClientId: this.ManagedIdentityClientId, useVSCodeCredential: this.UseVSCodeCredential, forceEmulator: this.Emulator, token: token);
var returnState = new CommandState
{
IsPrinted = true,
};
var endpoint = ParsedDocDBConnectionString.ExtractEndpoint(this.ConnectionString);
var endpoint = this.ConnectionString is null
? null
: ParsedDocDBConnectionString.ExtractEndpoint(this.ConnectionString);
var resultElement = JsonSerializer.SerializeToElement(new Dictionary<string, string?>
{
["connected state"] = endpoint,
Expand Down
162 changes: 143 additions & 19 deletions CosmosDBShell/Azure.Data.Cosmos.Shell.Core/ShellInterpreter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ internal async Task<CommandState> RunCommandAsync(CommandState currentState, str
return currentState;
}

internal async Task ConnectAsync(string connectionString, string? loginHint = null, ConnectionMode? mode = null, string? tenantId = null, string? authorityHost = null, string? managedIdentityClientId = null, bool useVSCodeCredential = false, CancellationToken token = default)
internal async Task ConnectAsync(string? connectionString, string? loginHint = null, ConnectionMode? mode = null, string? tenantId = null, string? authorityHost = null, string? managedIdentityClientId = null, bool useVSCodeCredential = false, bool forceEmulator = false, CancellationToken token = default)
{
token.ThrowIfCancellationRequested();

Expand All @@ -583,8 +583,25 @@ internal async Task ConnectAsync(string connectionString, string? loginHint = nu
CosmosClient? client = null;

// Step 1: Resolve account key (from connection string, env variable, or emulator well-known key)
bool isEmulator = ParsedDocDBConnectionString.IsLocalEmulatorEndpoint(connectionString);
if (isEmulator)
if (forceEmulator)
{
if (string.IsNullOrWhiteSpace(connectionString))
{
connectionString = "https://localhost:8081/";
}
else if (!ParsedDocDBConnectionString.IsLocalEmulatorEndpoint(connectionString))
{
throw new ShellException(MessageService.GetString("command-connect-emulator-non_local"));
}
Comment thread
mkrueger marked this conversation as resolved.
}

if (string.IsNullOrWhiteSpace(connectionString))
{
throw new ShellException(MessageService.GetString("command-connect-error-no_endpoint"));
}

bool isEmulator = forceEmulator || ParsedDocDBConnectionString.IsLocalEmulatorEndpoint(connectionString);
if (isEmulator && !forceEmulator)
{
WriteLine(MessageService.GetString("command-connect-emulator-detected"));
}
Expand Down Expand Up @@ -625,26 +642,21 @@ internal async Task ConnectAsync(string connectionString, string? loginHint = nu
{
WriteLine(MessageService.GetString("shell-connect-key-auth"));
var keyMode = mode ?? (isEmulator ? ConnectionMode.Gateway : ConnectionMode.Direct);
var keyOptions = CreateClientOptions(connectionString, keyMode);
client = new CosmosClient(connectionString, keyOptions);

AccountProperties keyProps;
try
{
keyProps = await ReadAccountAsync(client, token);
}
catch (OperationCanceledException) when (token.IsCancellationRequested)
{
client.Dispose();
throw;
}
catch (Exception ex)
(CosmosClient connectedClient, AccountProperties keyProps, string finalEndpoint) = await ConnectWithAccountKeyAsync(
connectionString,
keyMode,
isEmulator,
token);
client = connectedClient;

WriteLine(MessageService.GetArgsString("command-connect-connected", "account", keyProps.Id));

if (isEmulator)
{
client.Dispose();
throw new ShellException(MessageService.GetString("error-connection_failed"), ex);
ReportEmulatorProtocol(finalEndpoint);
}

WriteLine(MessageService.GetArgsString("command-connect-connected", "account", keyProps.Id));
this.Connect(client);
return;
}
Expand Down Expand Up @@ -903,6 +915,118 @@ private static async Task<AccountProperties> ReadAccountAsync(CosmosClient clien
return await client.ReadAccountAsync().WaitAsync(token);
}

/// <summary>
/// Connects with an account key, with an emulator-only HTTPS to HTTP fallback when the
/// underlying TLS handshake fails. Returns the live client, the account properties, and
/// the endpoint that was actually used.
/// </summary>
private static async Task<(CosmosClient Client, AccountProperties Properties, string Endpoint)> ConnectWithAccountKeyAsync(
string connectionString,
ConnectionMode keyMode,
bool isEmulator,
CancellationToken token)
{
var endpoint = ParsedDocDBConnectionString.ExtractEndpoint(connectionString);
var keyOptions = CreateClientOptions(connectionString, keyMode);
var client = new CosmosClient(connectionString, keyOptions);

try
{
var properties = await ReadAccountAsync(client, token);
return (client, properties, endpoint);
}
catch (OperationCanceledException) when (token.IsCancellationRequested)
{
client.Dispose();
throw;
}
catch (Exception ex)
{
client.Dispose();

if (isEmulator && IsTlsHandshakeFailure(ex) &&
Uri.TryCreate(endpoint, UriKind.Absolute, out var endpointUri) &&
endpointUri.Scheme == Uri.UriSchemeHttps)
{
var httpEndpoint = new UriBuilder(endpointUri) { Scheme = Uri.UriSchemeHttp, Port = endpointUri.Port }.Uri.ToString();
WriteLine(MessageService.GetString("command-connect-emulator-https_failed"));

var fallbackConnectionString = ParsedDocDBConnectionString.BuildEmulatorConnectionString(
httpEndpoint,
ParsedDocDBConnectionString.TryParseDocDBConnectionString(connectionString, out var parsed) ? parsed?.MasterKey : null);
var fallbackOptions = CreateClientOptions(fallbackConnectionString, keyMode);
var fallbackClient = new CosmosClient(fallbackConnectionString, fallbackOptions);
try
{
var fallbackProperties = await ReadAccountAsync(fallbackClient, token);
return (fallbackClient, fallbackProperties, httpEndpoint);
}
catch (OperationCanceledException) when (token.IsCancellationRequested)
{
fallbackClient.Dispose();
throw;
}
catch (Exception fallbackEx)
{
fallbackClient.Dispose();
var aggregated = new AggregateException(
MessageService.GetString("command-connect-emulator-fallback-failed"),
ex,
fallbackEx);
throw new ShellException(MessageService.GetString("error-connection_failed"), aggregated);
}
}

throw new ShellException(MessageService.GetString("error-connection_failed"), ex);
}
}
Comment thread
mkrueger marked this conversation as resolved.

/// <summary>
/// Detects TLS handshake / certificate-validation failures in the exception chain. Used to
/// decide whether an emulator HTTPS attempt should fall back to HTTP. Uses type-based checks
/// so the decision is not affected by localized exception messages.
/// </summary>
internal static bool IsTlsHandshakeFailure(Exception ex)
{
for (var current = ex; current != null; current = current.InnerException)
{
switch (current)
{
case System.Security.Authentication.AuthenticationException:
case System.Security.Cryptography.CryptographicException:
return true;
case System.Net.Sockets.SocketException socketEx
when socketEx.SocketErrorCode is
System.Net.Sockets.SocketError.ConnectionReset or
System.Net.Sockets.SocketError.ConnectionAborted:
// The TLS layer typically surfaces handshake failures from a non-TLS server
// as a reset/aborted socket while the request is still in the TLS handshake
// phase, before any HTTP exchange has happened.
return true;
}
Comment thread
mkrueger marked this conversation as resolved.
}

return false;
}

private static void ReportEmulatorProtocol(string endpoint)
{
if (!Uri.TryCreate(endpoint, UriKind.Absolute, out var uri))
{
return;
}

if (uri.Scheme == Uri.UriSchemeHttps)
{
WriteLine(MessageService.GetArgsString("command-connect-emulator-using_https", "endpoint", endpoint));
}
else
{
WriteLine(MessageService.GetArgsString("command-connect-emulator-using_http", "endpoint", endpoint));
WriteLine(MessageService.GetString("command-connect-emulator-http_tip"));
}
}

/// <summary>
/// Connects to a client & disposes old state.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class LocalizableSentenceBuilder : SentenceBuilder

public static string ConnectVSCodeCredential => MessageService.GetString("help-ConnectVSCodeCredential");

public static string ConnectEmulator => MessageService.GetString("help-ConnectEmulator");

public static string Command => MessageService.GetString("help-cmd");

public static string EnableMcpServer => MessageService.GetString("help-EnableMcpServer");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,33 @@ public static bool IsLocalEmulatorEndpoint(string? connectionStringOrEndpoint)
return false;
}

return connectionStringOrEndpoint.Contains("localhost", StringComparison.OrdinalIgnoreCase)
|| connectionStringOrEndpoint.Contains("127.0.0.1", StringComparison.OrdinalIgnoreCase);
// Resolve to a clean endpoint URL whether the input is a full connection string
// ("AccountEndpoint=...;...") or a plain URL.
string? endpoint;
if (TryParseDocDBConnectionString(connectionStringOrEndpoint, out var parsed))
{
endpoint = parsed!.Endpoint;
}
else if (IsPlainUrl(connectionStringOrEndpoint))
{
endpoint = connectionStringOrEndpoint;
}
else
{
return false;
}

if (!Uri.TryCreate(endpoint, UriKind.Absolute, out var uri))
{
return false;
}

if (string.Equals(uri.Host, "localhost", StringComparison.OrdinalIgnoreCase))
{
return true;
}

return System.Net.IPAddress.TryParse(uri.Host, out var ip) && System.Net.IPAddress.IsLoopback(ip);
}

/// <summary>
Expand Down
Loading
Loading