Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions CosmosDBShell/Azure.Data.Cosmos.Shell.Commands/CdCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public override async Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co
// Handle "cd" with no arguments - go to root
if (targetDatabase == null && targetContainer == null)
{
SetState(shell, new ConnectedState(connectedState.Client));
SetState(shell, new ConnectedState(connectedState.Client, connectedState.ArmContext));
if (!this.Quiet)
{
ShellInterpreter.WriteLine(MessageService.GetString("command-cd-changed_to_connected_state"));
Expand All @@ -93,11 +93,11 @@ public override async Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co
// Validate and navigate to database
if (targetDatabase != null)
{
await ValidateDatabaseExistsAsync(connectedState.Client, targetDatabase, "cd", token);
await ValidateDatabaseExistsAsync(connectedState, targetDatabase, "cd", token);

if (targetContainer == null)
{
SetState(shell, new DatabaseState(targetDatabase, connectedState.Client));
SetState(shell, new DatabaseState(targetDatabase, connectedState.Client, connectedState.ArmContext));
if (!this.Quiet)
{
ShellInterpreter.WriteLine(MessageService.GetString("command-cd-changed_to_db", new Dictionary<string, object> { { "db", targetDatabase } }));
Expand All @@ -108,8 +108,8 @@ public override async Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co
}

// Continue to navigate to container
await ValidateContainerExistsAsync(connectedState.Client, targetDatabase, targetContainer, "cd", token);
SetState(shell, new ContainerState(targetContainer, targetDatabase, connectedState.Client));
await ValidateContainerExistsAsync(connectedState, targetDatabase, targetContainer, "cd", token);
SetState(shell, new ContainerState(targetContainer, targetDatabase, connectedState.Client, connectedState.ArmContext));
if (!this.Quiet)
{
ShellInterpreter.WriteLine(MessageService.GetString("command-cd-changed_to_container", new Dictionary<string, object> { { "container", targetContainer } }));
Expand All @@ -128,8 +128,8 @@ public override async Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co
throw new NotInDatabaseException("cd");
}

await ValidateContainerExistsAsync(connectedState.Client, dbName, targetContainer, "cd", token);
SetState(shell, new ContainerState(targetContainer, dbName, connectedState.Client));
await ValidateContainerExistsAsync(connectedState, dbName, targetContainer, "cd", token);
SetState(shell, new ContainerState(targetContainer, dbName, connectedState.Client, connectedState.ArmContext));
if (!this.Quiet)
{
ShellInterpreter.WriteLine(MessageService.GetString("command-cd-changed_to_container", new Dictionary<string, object> { { "container", targetContainer } }));
Expand Down
17 changes: 16 additions & 1 deletion CosmosDBShell/Azure.Data.Cosmos.Shell.Commands/ConnectCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ internal partial class ConnectCommand : CosmosCommand
[CosmosOption("managed-identity")]
public string? ManagedIdentityClientId { get; set; }

[CosmosOption("subscription")]
public string? SubscriptionId { get; set; }

[CosmosOption("resource-group")]
public string? ResourceGroupName { get; set; }

[CosmosOption("account")]
public string? AccountName { get; set; }

[CosmosOption("vscode-credential", "connect-vscode-credential", Hidden = true)]
public bool UseVSCodeCredential { get; init; }

Expand Down Expand Up @@ -85,7 +94,7 @@ public async override Task<CommandState> ExecuteAsync(ShellInterpreter shell, Co

try
{
await shell.ConnectAsync(this.ConnectionString, this.LoginHint, connectionMode, tenantId: this.TenantId, authorityHost: this.AuthorityHost, managedIdentityClientId: this.ManagedIdentityClientId, useVSCodeCredential: this.UseVSCodeCredential, token: token);
await shell.ConnectAsync(this.ConnectionString, this.LoginHint, connectionMode, tenantId: this.TenantId, authorityHost: this.AuthorityHost, managedIdentityClientId: this.ManagedIdentityClientId, useVSCodeCredential: this.UseVSCodeCredential, subscriptionId: this.SubscriptionId, resourceGroupName: this.ResourceGroupName, accountName: this.AccountName, token: token);
var returnState = new CommandState
{
IsPrinted = true,
Expand Down Expand Up @@ -162,6 +171,11 @@ private static async Task<CommandState> PrintConnectionInfoAsync(ShellInterprete
table.AddRow(MessageService.GetString("command-connect-info-account"), $"[white]{acc.Id}[/]");
table.AddRow(MessageService.GetString("command-connect-info-endpoint"), $"[white]{client.Endpoint}[/]");

if (connectedState.ArmContext != null)
{
table.AddRow(MessageService.GetString("command-connect-info-arm-account"), $"[white]{connectedState.ArmContext.AccountResourceId}[/]");
}

// Display the connection mode
var connectionMode = client.ClientOptions.ConnectionMode;
table.AddRow(MessageService.GetString("command-connect-info-mode"), $"[white]{connectionMode}[/]");
Expand All @@ -183,6 +197,7 @@ private static async Task<CommandState> PrintConnectionInfoAsync(ShellInterprete
["connected"] = true,
["accountId"] = acc.Id,
["endpoint"] = client.Endpoint.ToString(),
["armAccountId"] = connectedState.ArmContext?.AccountResourceId.ToString(),
["connectionMode"] = connectionMode.ToString(),
["readRegions"] = acc.ReadableRegions.Select(r => r.Name).ToArray(),
["writeRegions"] = acc.WritableRegions.Select(r => r.Name).ToArray(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitConn
{
if (!string.IsNullOrEmpty(this.Database) && !string.IsNullOrEmpty(this.Container))
{
return await this.ExecuteOnContainerAsync(state.Client, this.Database, this.Container, token);
return await this.ExecuteOnContainerAsync(state, this.Database, this.Container, token);
}

throw new NotInContainerException("indexpolicy");
Expand All @@ -69,7 +69,7 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitData

if (!string.IsNullOrEmpty(this.Container))
{
return await this.ExecuteOnContainerAsync(state.Client, databaseName, this.Container, token);
return await this.ExecuteOnContainerAsync(state, databaseName, this.Container, token);
}

throw new NotInContainerException("indexpolicy");
Expand All @@ -80,20 +80,12 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitCont
string databaseName = this.Database ?? state.DatabaseName;
string containerName = this.Container ?? state.ContainerName;

return await this.ExecuteOnContainerAsync(state.Client, databaseName, containerName, token);
return await this.ExecuteOnContainerAsync(state, databaseName, containerName, token);
}

private static async Task<CommandState> ReadIndexPolicyAsync(Container container, CancellationToken token)
private static async Task<CommandState> ReadIndexPolicyAsync(ConnectedState state, string databaseName, string containerName, CancellationToken token)
{
var containerResponse = await container.ReadContainerAsync(cancellationToken: token);
var resource = containerResponse.Resource;
if (resource == null)
{
throw new CommandException("indexpolicy", MessageService.GetString("error-unable_to_read_container"));
}

var indexingPolicy = resource.IndexingPolicy;
var json = Newtonsoft.Json.JsonConvert.SerializeObject(indexingPolicy, Newtonsoft.Json.Formatting.Indented);
var json = await CosmosResourceFacade.GetIndexingPolicyJsonAsync(state, databaseName, containerName, token);

ShellInterpreter.WriteLine(json);

Expand All @@ -106,50 +98,34 @@ private static async Task<CommandState> ReadIndexPolicyAsync(Container container
return commandState;
}

private async Task<CommandState> ExecuteOnContainerAsync(CosmosClient client, string databaseName, string containerName, CancellationToken token)
private async Task<CommandState> ExecuteOnContainerAsync(ConnectedState state, string databaseName, string containerName, CancellationToken token)
{
await ValidateContainerExistsAsync(client, databaseName, containerName, "indexpolicy", token);

var container = client.GetDatabase(databaseName).GetContainer(containerName);
await ValidateContainerExistsAsync(state, databaseName, containerName, "indexpolicy", token);

if (string.IsNullOrEmpty(this.Policy))
{
return await ReadIndexPolicyAsync(container, token);
return await ReadIndexPolicyAsync(state, databaseName, containerName, token);
}

return await this.WriteIndexPolicyAsync(container, token);
return await this.WriteIndexPolicyAsync(state, databaseName, containerName, token);
}

private async Task<CommandState> WriteIndexPolicyAsync(Container container, CancellationToken token)
private async Task<CommandState> WriteIndexPolicyAsync(ConnectedState state, string databaseName, string containerName, CancellationToken token)
{
IndexingPolicy indexingPolicy;
string updatedJson;
try
{
indexingPolicy = Newtonsoft.Json.JsonConvert.DeserializeObject<IndexingPolicy>(this.Policy!)
?? throw new CommandException("indexpolicy", MessageService.GetString("command-indexpolicy-error_invalid_policy"));
updatedJson = await CosmosResourceFacade.ReplaceIndexingPolicyAsync(state, databaseName, containerName, this.Policy!, token);
}
catch (Newtonsoft.Json.JsonException ex)
catch (Exception ex) when (ex is JsonException or FormatException or InvalidOperationException)
{
throw new CommandException("indexpolicy", MessageService.GetString("command-indexpolicy-error_invalid_policy"), ex);
}

var containerResponse = await container.ReadContainerAsync(cancellationToken: token);
var resource = containerResponse.Resource;
if (resource == null)
{
throw new CommandException("indexpolicy", MessageService.GetString("error-unable_to_read_container"));
}

resource.IndexingPolicy = indexingPolicy;
var replaceResponse = await container.ReplaceContainerAsync(resource, cancellationToken: token);

var updatedPolicy = replaceResponse.Resource.IndexingPolicy;
var json = Newtonsoft.Json.JsonConvert.SerializeObject(updatedPolicy, Newtonsoft.Json.Formatting.Indented);

ShellInterpreter.WriteLine(MessageService.GetString("command-indexpolicy-updated"));
ShellInterpreter.WriteLine(json);
ShellInterpreter.WriteLine(updatedJson);

using var jsonDoc = JsonDocument.Parse(json);
using var jsonDoc = JsonDocument.Parse(updatedJson);
var commandState = new CommandState
{
IsPrinted = true,
Expand Down
46 changes: 23 additions & 23 deletions CosmosDBShell/Azure.Data.Cosmos.Shell.Commands/ListCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitCont
string databaseName = this.Database ?? state.DatabaseName;
string containerName = this.Container ?? state.ContainerName;

return await this.ListContainerItemsAsync(state.Client, databaseName, containerName, token);
return await this.ListContainerItemsAsync(state, databaseName, containerName, token);
}

Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitDisconnectedStateAsync(DisconnectedState state, ShellInterpreter interpreter, CancellationToken token)
Expand All @@ -67,26 +67,26 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitConn
{
if (!string.IsNullOrEmpty(this.Container))
{
return await this.ListContainerItemsAsync(state.Client, this.Database, this.Container, token);
return await this.ListContainerItemsAsync(state, this.Database, this.Container, token);
}

return await this.ListDatabaseContainersAsync(state.Client, this.Database, token);
return await this.ListDatabaseContainersAsync(state, this.Database, token);
}

// Default behavior: list databases
var list = new List<string>();
var completionList = new List<string>();
await foreach (var database in EnumerateDatabasesAsync(state.Client))
await foreach (var databaseName in EnumerateDatabaseNamesAsync(state, "ls", token))
{
var databaseName = database.Id.Trim();
completionList.Add(databaseName);
var trimmed = databaseName.Trim();
completionList.Add(trimmed);

if (!this.IsMatch(database.Id))
if (!this.IsMatch(trimmed))
{
continue;
}

var cn = Markup.Escape(databaseName);
var cn = Markup.Escape(trimmed);
list.Add(cn);
AnsiConsole.MarkupLine($"[green]{cn}[/]");
}
Expand All @@ -108,36 +108,35 @@ async Task<CommandState> IStateVisitor<CommandState, ShellInterpreter>.VisitData
// If container is specified, list items in that container
if (!string.IsNullOrEmpty(this.Container))
{
return await this.ListContainerItemsAsync(state.Client, databaseName, this.Container, token);
return await this.ListContainerItemsAsync(state, databaseName, this.Container, token);
}

// Default behavior: list containers in the database
return await this.ListDatabaseContainersAsync(state.Client, databaseName, token);
return await this.ListDatabaseContainersAsync(state, databaseName, token);
}

private async Task<CommandState> ListDatabaseContainersAsync(CosmosClient client, string databaseName, CancellationToken token)
private async Task<CommandState> ListDatabaseContainersAsync(ConnectedState state, string databaseName, CancellationToken token)
{
// Validate database exists
await ValidateDatabaseExistsAsync(client, databaseName, "ls", token);
var db = client.GetDatabase(databaseName);
await ValidateDatabaseExistsAsync(state, databaseName, "ls", token);
var list = new List<string>();
var completionList = new List<string>();
await foreach (var container in EnumerateContainersAsync(db))
await foreach (var containerName in EnumerateContainerNamesAsync(state, databaseName, "ls", token))
{
var containerName = container.Id.Trim();
completionList.Add(containerName);
var trimmed = containerName.Trim();
completionList.Add(trimmed);

if (!this.IsMatch(container.Id))
if (!this.IsMatch(trimmed))
{
continue;
}

var cn = Markup.Escape(containerName);
var cn = Markup.Escape(trimmed);
list.Add(cn);
AnsiConsole.MarkupLine($"[magenta]{cn}[/]");
}

CosmosCompleteCommand.SetContainers(client, databaseName, completionList);
CosmosCompleteCommand.SetContainers(state.Client, databaseName, completionList);

var result = new CommandState
{
Expand All @@ -147,11 +146,12 @@ private async Task<CommandState> ListDatabaseContainersAsync(CosmosClient client
return result;
}

private async Task<CommandState> ListContainerItemsAsync(CosmosClient client, string databaseName, string containerName, CancellationToken token)
private async Task<CommandState> ListContainerItemsAsync(ConnectedState state, string databaseName, string containerName, CancellationToken token)
{
// Validate database and container exist
await ValidateContainerExistsAsync(client, databaseName, containerName, "ls", token);
await ValidateContainerExistsAsync(state, databaseName, containerName, "ls", token);

var client = state.Client;
var container = client.GetDatabase(databaseName).GetContainer(containerName);
AnsiConsole.MarkupLine(MessageService.GetString("command-ls-container", new Dictionary<string, object> { { "container", Theme.ContainerNamePromt(container.Id) } }));
var opt = new QueryRequestOptions();
Expand All @@ -161,8 +161,8 @@ private async Task<CommandState> ListContainerItemsAsync(CosmosClient client, st
opt.MaxItemCount = effectiveMaxItemCount.Value;
}

var containerResponse = await container.ReadContainerAsync(cancellationToken: token);
var partitionKeyPropertyNames = GetPartitionKeyPropertyNames(containerResponse.Resource.PartitionKeyPaths);
var partitionKeyPaths = await CosmosResourceFacade.GetPartitionKeyPathsAsync(state, databaseName, containerName, token);
var partitionKeyPropertyNames = GetPartitionKeyPropertyNames(partitionKeyPaths);
var matchKeyPropertyNames = string.IsNullOrEmpty(this.Key) ? partitionKeyPropertyNames : [this.Key];

var queryText = BuildItemQueryText(effectiveMaxItemCount, this.Filter);
Expand Down
Loading
Loading