diff --git a/README.md b/README.md index 2a5dab38..eda8e39e 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ these are the supported agents: Agent README files: +- [`az-cli` & `az-ps`][13] - [`openai-gpt`][08] - [`ollama`][06] - [`interpreter`][07] @@ -158,3 +159,4 @@ bugs, suggestions, or feedback. [10]: https://github.com/PowerShell/ProjectMercury/issues [11]: https://learn.microsoft.com/powershell/scripting/install/installing-powershell [12]: ./docs/SECURITY.md +[13]: ./shell/agents/AIShell.Azure.Agent/README.md diff --git a/build.psm1 b/build.psm1 index b1516c6f..495c3fe6 100644 --- a/build.psm1 +++ b/build.psm1 @@ -20,7 +20,7 @@ function Start-Build [string] $Runtime = [NullString]::Value, [Parameter()] - [ValidateSet('openai-gpt', 'interpreter', 'ollama')] + [ValidateSet('openai-gpt', 'az-agent', 'interpreter', 'ollama')] [string[]] $AgentToInclude, [Parameter()] @@ -40,7 +40,7 @@ function Start-Build if (-not $AgentToInclude) { $agents = $metadata.AgentsToInclude $AgentToInclude = if ($agents -eq "*") { - @('openai-gpt', 'interpreter', 'ollama') + @('openai-gpt', 'az-agent', 'interpreter', 'ollama') } else { $agents.Split(",", [System.StringSplitOptions]::TrimEntries) Write-Verbose "Include agents specified in Metadata.json" @@ -63,6 +63,7 @@ function Start-Build $module_dir = Join-Path $shell_dir "AIShell.Integration" $openai_agent_dir = Join-Path $agent_dir "AIShell.OpenAI.Agent" + $az_agent_dir = Join-Path $agent_dir "AIShell.Azure.Agent" $interpreter_agent_dir = Join-Path $agent_dir "AIShell.Interpreter.Agent" $ollama_agent_dir = Join-Path $agent_dir "AIShell.Ollama.Agent" @@ -73,6 +74,7 @@ function Start-Build $module_help_dir= Join-Path $PSScriptRoot "docs" "cmdlets" $openai_out_dir = Join-Path $app_out_dir "agents" "AIShell.OpenAI.Agent" + $az_out_dir = Join-Path $app_out_dir "agents" "AIShell.Azure.Agent" $interpreter_out_dir = Join-Path $app_out_dir "agents" "AIShell.Interpreter.Agent" $ollama_out_dir = Join-Path $app_out_dir "agents" "AIShell.Ollama.Agent" @@ -93,6 +95,12 @@ function Start-Build dotnet publish $openai_csproj -c $Configuration -o $openai_out_dir } + if ($LASTEXITCODE -eq 0 -and $AgentToInclude -contains 'az-agent') { + Write-Host "`n[Build the Azure agents ...]`n" -ForegroundColor Green + $az_csproj = GetProjectFile $az_agent_dir + dotnet publish $az_csproj -c $Configuration -o $az_out_dir + } + if ($LASTEXITCODE -eq 0 -and $AgentToInclude -contains 'interpreter') { Write-Host "`n[Build the Interpreter agent ...]`n" -ForegroundColor Green $interpreter_csproj = GetProjectFile $interpreter_agent_dir diff --git a/docs/FAQ.md b/docs/FAQ.md index e1c54e0d..d1e41011 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -18,6 +18,7 @@ agents: Agent README files: +- [`az-cli` & `az-ps`][05] - [`openai-gpt`][04] - [`ollama`][02] - [`interpreter`][03] @@ -44,3 +45,4 @@ documentation for your terminal application to see if it supports this feature. [02]: ../shell/agents/AIShell.Ollama.Agent/README.md [03]: ../shell/agents/AIShell.Interpreter.Agent/README.md [04]: ../shell/agents/AIShell.OpenAI.Agent/README.md +[05]: ../shell/agents/AIShell.Azure.Agent/README.md diff --git a/shell/AIShell.Abstraction/IHost.cs b/shell/AIShell.Abstraction/IHost.cs index fcdc20b9..8cbfbe41 100644 --- a/shell/AIShell.Abstraction/IHost.cs +++ b/shell/AIShell.Abstraction/IHost.cs @@ -93,6 +93,12 @@ public interface IHost /// Label and value pairs to be rendered for the object. void RenderList(T source, IList> elements); + /// + /// Render a divider with the passed-in text. + /// + /// A brief caption for the subsequent section. + void RenderDivider(string text); + /// /// Run an asynchronouse task with a spinner on the console showing the task in progress. /// @@ -100,7 +106,7 @@ public interface IHost /// The asynchronouse task. /// The status message to be shown. /// The returned result of - Task RunWithSpinnerAsync(Func> func, string status); + Task RunWithSpinnerAsync(Func> func, string status, SpinnerKind? spinnerKind); /// /// Run an asynchronouse task with a spinner on the console showing the task in progress. @@ -110,12 +116,22 @@ public interface IHost /// The asynchronouse task, which can change the status of the spinner. /// The initial status message to be shown. /// The returned result of - Task RunWithSpinnerAsync(Func> func, string status); + Task RunWithSpinnerAsync(Func> func, string status, SpinnerKind? spinnerKind); + + /// + /// Run an asynchronouse task with the default spinner and the default status message. + /// + Task RunWithSpinnerAsync(Func> func) => RunWithSpinnerAsync(func, "Generating...", SpinnerKind.Generating); /// - /// Run an asynchronouse task with a spinner with the default status message. + /// Run an asynchronouse task with the default spinner and the specified status message. /// - Task RunWithSpinnerAsync(Func> func) => RunWithSpinnerAsync(func, "Generating..."); + Task RunWithSpinnerAsync(Func> func, string status) => RunWithSpinnerAsync(func, status, SpinnerKind.Generating); + + /// + /// Run an asynchronouse task that allows changing the status message with the default spinner and the specified initial status message. + /// + Task RunWithSpinnerAsync(Func> func, string status) => RunWithSpinnerAsync(func, status, SpinnerKind.Generating); /// /// Prompt for selection asynchronously. @@ -173,9 +189,9 @@ Task PromptForTextAsync(string prompt, bool optional, CancellationToken /// Prompt for the user to input the value for an argument placeholder. /// /// Information about the argument placeholder. - /// Token to cancel operation. + /// Indicates if the caption, such as the description and restriction, should be printed. /// - string PromptForArgument(ArgumentInfo argInfo, CancellationToken cancellationToken); + string PromptForArgument(ArgumentInfo argInfo, bool printCaption); } /// @@ -189,19 +205,38 @@ public interface IStatusContext void Status(string status); } +/// +/// Enum type for the kind of spinner to use. +/// +public enum SpinnerKind +{ + /// + /// This spinner indicates text is being generated. + /// It should be used when generating response in chat. + /// This is the default spinner kind used by the host. + /// + Generating, + + /// + /// This spinner indicates a general task processing. + /// It should be used in all other cases, such as loading data, etc. + /// + Processing, +} + /// /// Information about an argument placeholder. /// -public sealed class ArgumentInfo +public class ArgumentInfo { /// /// Type of the argument data. /// public enum DataType { - String, - Int, - Bool, + @string, + @int, + @bool, } /// @@ -224,18 +259,13 @@ public enum DataType /// public DataType Type { get; } - /// - /// Gets a value indicating whether the user must choose from the suggestions. - /// - public bool MustChooseFromSuggestions { get; } - /// /// Gets the list of suggestions for the argument. /// public IList Suggestions { get; } public ArgumentInfo(string name, string description, DataType dataType) - : this(name, description, restriction: null, dataType, mustChooseFromSuggestions: false, suggestions: null) + : this(name, description, restriction: null, dataType, suggestions: null) { } @@ -244,24 +274,15 @@ public ArgumentInfo( string description, string restriction, DataType dataType, - bool mustChooseFromSuggestions, IList suggestions) { ArgumentException.ThrowIfNullOrEmpty(name); ArgumentException.ThrowIfNullOrEmpty(description); - if (mustChooseFromSuggestions && (suggestions is null || suggestions.Count < 2)) - { - throw new ArgumentException( - $"A suggestion list with at least 2 items is required when '{nameof(MustChooseFromSuggestions)}' is true.", - nameof(suggestions)); - } - Name = name; Description = description; Restriction = restriction; Type = dataType; - MustChooseFromSuggestions = mustChooseFromSuggestions; Suggestions = suggestions; } } diff --git a/shell/AIShell.Kernel/Host.cs b/shell/AIShell.Kernel/Host.cs index bfed03f9..a6b56e97 100644 --- a/shell/AIShell.Kernel/Host.cs +++ b/shell/AIShell.Kernel/Host.cs @@ -310,7 +310,18 @@ public void RenderList(T source, IList> elements) } /// - public async Task RunWithSpinnerAsync(Func> func, string status = null) + public void RenderDivider(string text) + { + ArgumentException.ThrowIfNullOrEmpty(text); + RequireStdoutOrStderr(operation: "render divider"); + + AnsiConsole.Write(new Rule($"[yellow]{text.EscapeMarkup()}[/]") + .RuleStyle("grey") + .LeftJustified()); + } + + /// + public async Task RunWithSpinnerAsync(Func> func, string status = null, SpinnerKind? spinnerKind = null) { if (_outputRedirected && _errorRedirected) { @@ -333,7 +344,7 @@ public async Task RunWithSpinnerAsync(Func> func, string status = return await ansiConsole .Status() .AutoRefresh(true) - .Spinner(AsciiLetterSpinner.Default) + .Spinner(GetSpinner(spinnerKind)) .SpinnerStyle(new Style(Color.Olive)) .StartAsync( $"[italic slowblink]{status.EscapeMarkup()}[/]", @@ -347,7 +358,7 @@ public async Task RunWithSpinnerAsync(Func> func, string status = } /// - public async Task RunWithSpinnerAsync(Func> func, string status) + public async Task RunWithSpinnerAsync(Func> func, string status, SpinnerKind? spinnerKind = null) { if (_outputRedirected && _errorRedirected) { @@ -370,7 +381,7 @@ public async Task RunWithSpinnerAsync(Func> func, return await ansiConsole .Status() .AutoRefresh(true) - .Spinner(AsciiLetterSpinner.Default) + .Spinner(GetSpinner(spinnerKind)) .SpinnerStyle(new Style(Color.Olive)) .StartAsync( $"[italic slowblink]{status.EscapeMarkup()}[/]", @@ -468,50 +479,45 @@ public async Task PromptForTextAsync(string prompt, bool optional, IList } /// - public string PromptForArgument(ArgumentInfo argInfo, CancellationToken cancellationToken) + public string PromptForArgument(ArgumentInfo argInfo, bool printCaption) { - WriteLine($"{argInfo.Name}: {argInfo.Description}."); - if (!string.IsNullOrEmpty(argInfo.Restriction)) + if (printCaption) { - WriteLine(argInfo.Restriction); - } + WriteLine(argInfo.Type is ArgumentInfo.DataType.@string + ? argInfo.Description + : $"{argInfo.Description}. Value type: {argInfo.Type}"); - if (argInfo.Type is ArgumentInfo.DataType.Bool || argInfo.Suggestions?.Count is 2) - { - return PromptForTextAsync( - prompt: ">", - optional: false, - choices: argInfo.Suggestions ?? ["ture", "flase"], - cancellationToken: cancellationToken).GetAwaiter().GetResult(); + if (!string.IsNullOrEmpty(argInfo.Restriction)) + { + WriteLine(argInfo.Restriction); + } } - if (argInfo.MustChooseFromSuggestions) + var suggestions = argInfo.Suggestions; + if (argInfo.Type is ArgumentInfo.DataType.@bool) { - string value = PromptForSelectionAsync( - title: "Choose the value from the below list:", - choices: argInfo.Suggestions, - cancellationToken: cancellationToken).GetAwaiter().GetResult(); - WriteLine($"> {value}"); - return value; + suggestions ??= ["ture", "flase"]; } var options = PSConsoleReadLine.GetOptions(); + var oldAddToHistoryHandler = options.AddToHistoryHandler; var oldReadLineHelper = options.ReadLineHelper; var oldPredictionView = options.PredictionViewStyle; var oldPredictionSource = options.PredictionSource; var newOptions = new SetPSReadLineOption { - ReadLineHelper = new PromptHelper(argInfo.Suggestions), + AddToHistoryHandler = c => AddToHistoryOption.SkipAdding, + ReadLineHelper = new PromptHelper(suggestions), PredictionSource = PredictionSource.Plugin, PredictionViewStyle = PredictionViewStyle.ListView, }; try { - Write("> "); + Markup($"[lime]{argInfo.Name}[/]: "); PSConsoleReadLine.SetOptions(newOptions); - string value = PSConsoleReadLine.ReadLine(); + string value = PSConsoleReadLine.ReadLine(CancellationToken.None); if (Console.CursorLeft is not 0) { // Ctrl+c was pressed by the user. @@ -523,6 +529,7 @@ public string PromptForArgument(ArgumentInfo argInfo, CancellationToken cancella } finally { + newOptions.AddToHistoryHandler = oldAddToHistoryHandler; newOptions.ReadLineHelper = oldReadLineHelper; newOptions.PredictionSource = oldPredictionSource; newOptions.PredictionViewStyle = oldPredictionView; @@ -549,6 +556,15 @@ internal void RenderReferenceText(string header, string content) AnsiConsole.WriteLine(); } + private static Spinner GetSpinner(SpinnerKind? kind) + { + return kind switch + { + SpinnerKind.Processing => Spinner.Known.Default, + _ => AsciiLetterSpinner.Default, + }; + } + /// /// Throw exception if standard input is redirected. /// diff --git a/shell/AIShell.Kernel/Shell.cs b/shell/AIShell.Kernel/Shell.cs index b7664ff7..39ec2fcb 100644 --- a/shell/AIShell.Kernel/Shell.cs +++ b/shell/AIShell.Kernel/Shell.cs @@ -604,7 +604,7 @@ internal async Task RunREPLAsync() { // Write out the remote query, in the same style as user typing. Host.Markup($"\n>> Remote Query Received:\n"); - Host.MarkupLine($"[teal]{input}[/]"); + Host.MarkupLine($"[teal]{input.EscapeMarkup()}[/]"); } else { @@ -736,7 +736,7 @@ Task find_agent_op() => orchestrator.FindAgentForPrompt( } Host.WriteErrorLine() - .WriteErrorLine($"Agent failed to generate a response: {ex.Message}") + .WriteErrorLine($"Agent failed to generate a response: {ex.Message}\n{ex.StackTrace}") .WriteErrorLine(); } } diff --git a/shell/AIShell.Kernel/Utility/ReadLineHelper.cs b/shell/AIShell.Kernel/Utility/ReadLineHelper.cs index e577e10d..7b8f67b5 100644 --- a/shell/AIShell.Kernel/Utility/ReadLineHelper.cs +++ b/shell/AIShell.Kernel/Utility/ReadLineHelper.cs @@ -371,7 +371,7 @@ internal class PromptHelper : IReadLineHelper internal PromptHelper(IList candidates) { _candidates = candidates; - _predictorName = "completion"; + _predictorName = "suggestion"; _predictorId = new Guid(GUID); } diff --git a/shell/agents/AIShell.Azure.Agent/AIShell.Azure.Agent.csproj b/shell/agents/AIShell.Azure.Agent/AIShell.Azure.Agent.csproj new file mode 100644 index 00000000..4f7d1834 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AIShell.Azure.Agent.csproj @@ -0,0 +1,32 @@ + + + + net8.0 + enable + true + + + false + + + + + false + None + + + + + + + + + + + false + + runtime + + + + diff --git a/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIAgent.cs b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIAgent.cs new file mode 100644 index 00000000..fdcc6231 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIAgent.cs @@ -0,0 +1,265 @@ +using System.Diagnostics; +using System.Text; +using System.Text.Json; +using Azure.Identity; +using AIShell.Abstraction; + +namespace AIShell.Azure.CLI; + +public sealed class AzCLIAgent : ILLMAgent +{ + public string Name => "az-cli"; + public string Description => "This AI assistant can help generate Azure CLI scripts or commands for managing Azure resources and end-to-end scenarios that involve multiple different Azure resources."; + public string Company => "Microsoft"; + public List SampleQueries => [ + "Create a VM with a public IP address", + "How to create a web app?", + "Backup an Azure SQL database to a storage container" + ]; + public Dictionary LegalLinks { private set; get; } = null; + public string SettingFile { private set; get; } = null; + internal ArgumentPlaceholder ArgPlaceholder { set; get; } + internal UserValueStore ValueStore { get; } = new(); + + private const string SettingFileName = "az-cli.agent.json"; + private readonly Stopwatch _watch = new(); + + private AzCLIChatService _chatService; + private StringBuilder _text; + private MetricHelper _metricHelper; + private LinkedList _historyForTelemetry; + + public void Dispose() + { + _chatService?.Dispose(); + } + + public void Initialize(AgentConfig config) + { + _text = new StringBuilder(); + _chatService = new AzCLIChatService(); + _historyForTelemetry = []; + _metricHelper = new MetricHelper(AzCLIChatService.Endpoint); + + LegalLinks = new(StringComparer.OrdinalIgnoreCase) + { + ["Terms"] = "https://aka.ms/TermsofUseCopilot", + ["Privacy"] = "https://aka.ms/privacy", + ["FAQ"] = "https://aka.ms/CopilotforAzureClientToolsFAQ", + ["Transparency"] = "https://aka.ms/CopilotAzCLIPSTransparency", + }; + + SettingFile = Path.Combine(config.ConfigurationRoot, SettingFileName); + } + + public void RefreshChat() + { + // Reset the history so the subsequent chat can start fresh. + _chatService.ChatHistory.Clear(); + ArgPlaceholder = null; + ValueStore.Clear(); + } + + public IEnumerable GetCommands() => [new ReplaceCommand(this)]; + + public bool CanAcceptFeedback(UserAction action) => !MetricHelper.TelemetryOptOut; + + public void OnUserAction(UserActionPayload actionPayload) + { + // Send telemetry about the user action. + // DisLike Action + string DetailedMessage = null; + LinkedList history = null; + if (actionPayload.Action == UserAction.Dislike) + { + DislikePayload dislikePayload = (DislikePayload)actionPayload; + DetailedMessage = string.Format("{0} | {1}", dislikePayload.ShortFeedback, dislikePayload.LongFeedback); + if (dislikePayload.ShareConversation) + { + history = _historyForTelemetry; + } + else + { + _historyForTelemetry.Clear(); + } + } + // Like Action + else if (actionPayload.Action == UserAction.Like) + { + LikePayload likePayload = (LikePayload)actionPayload; + if (likePayload.ShareConversation) + { + history = _historyForTelemetry; + } + else + { + _historyForTelemetry.Clear(); + } + } + + _metricHelper.LogTelemetry( + new AzTrace() + { + Command = actionPayload.Action.ToString(), + CorrelationID = _chatService.CorrelationID, + EventType = "Feedback", + Handler = "Azure CLI", + DetailedMessage = DetailedMessage, + HistoryMessage = history + }); + } + + public async Task Chat(string input, IShell shell) + { + // Measure time spent + _watch.Restart(); + var startTime = DateTime.Now; + + IHost host = shell.Host; + CancellationToken token = shell.CancellationToken; + + try + { + AzCliResponse azResponse = await host.RunWithSpinnerAsync( + status: "Thinking ...", + func: async context => await _chatService.GetChatResponseAsync(context, input, token) + ).ConfigureAwait(false); + + if (azResponse is not null) + { + if (azResponse.Error is not null) + { + host.WriteLine($"\n{azResponse.Error}\n"); + return true; + } + + ResponseData data = azResponse.Data; + AddMessageToHistory( + JsonSerializer.Serialize(data, Utils.JsonOptions), + fromUser: false); + + string answer = GenerateAnswer(input, data); + host.RenderFullResponse(answer); + + // Measure time spent + _watch.Stop(); + + if (!MetricHelper.TelemetryOptOut) + { + // TODO: extract into RecordQuestionTelemetry() : RecordTelemetry() + var EndTime = DateTime.Now; + var Duration = TimeSpan.FromTicks(_watch.ElapsedTicks); + + // Append last Q&A history in HistoryMessage + _historyForTelemetry.AddLast(new HistoryMessage("user", input, _chatService.CorrelationID)); + _historyForTelemetry.AddLast(new HistoryMessage("assistant", answer, _chatService.CorrelationID)); + + _metricHelper.LogTelemetry( + new AzTrace() + { + CorrelationID = _chatService.CorrelationID, + Duration = Duration, + EndTime = EndTime, + EventType = "Question", + Handler = "Azure CLI", + StartTime = startTime + }); + } + } + } + catch (RefreshTokenException ex) + { + Exception inner = ex.InnerException; + if (inner is CredentialUnavailableException) + { + host.WriteErrorLine($"Access token not available. Query cannot be served."); + host.WriteErrorLine($"The '{Name}' agent depends on the Azure CLI credential to acquire access token. Please run 'az login' from a command-line shell to setup account."); + } + else + { + host.WriteErrorLine($"Failed to get the access token. {inner.Message}"); + } + + return false; + } + finally + { + // Stop the watch in case of early return or exception. + _watch.Stop(); + } + + return true; + } + + internal string GenerateAnswer(string input, ResponseData data) + { + _text.Clear(); + _text.Append(data.Description).Append("\n\n"); + + // We keep 'ArgPlaceholder' unchanged when it's re-generating in '/replace' with only partial placeholders replaced. + if (!ReferenceEquals(ArgPlaceholder?.ResponseData, data) || data.PlaceholderSet is null) + { + ArgPlaceholder?.DataRetriever?.Dispose(); + ArgPlaceholder = null; + } + + if (data.CommandSet.Count > 0) + { + // AzCLI handler incorrectly include pseudo values in the placeholder set, so we need to filter them out. + UserValueStore.FilterOutPseudoValues(data); + if (data.PlaceholderSet?.Count > 0) + { + // Create the data retriever for the placeholders ASAP, so it gets + // more time to run in background. + ArgPlaceholder ??= new ArgumentPlaceholder(input, data); + } + + for (int i = 0; i < data.CommandSet.Count; i++) + { + CommandItem action = data.CommandSet[i]; + // Replace the pseudo values with the real values. + string script = ValueStore.ReplacePseudoValues(action.Script); + + _text.Append($"{i+1}. {action.Desc}") + .Append("\n\n") + .Append("```sh\n") + .Append($"# {action.Desc}\n") + .Append(script).Append('\n') + .Append("```\n\n"); + } + + if (ArgPlaceholder is not null) + { + _text.Append("Please provide values for the following placeholder variables:\n\n"); + + for (int i = 0; i < data.PlaceholderSet.Count; i++) + { + PlaceholderItem item = data.PlaceholderSet[i]; + _text.Append($"- `{item.Name}`: {item.Desc}\n"); + } + + _text.Append("\nRun `/replace` to get assistance in placeholder replacement.\n"); + } + } + + return _text.ToString(); + } + + internal void AddMessageToHistory(string message, bool fromUser) + { + if (!string.IsNullOrEmpty(message)) + { + var history = _chatService.ChatHistory; + while (history.Count > Utils.HistoryCount - 1) + { + history.RemoveAt(0); + } + + history.Add(new ChatMessage() + { + Role = fromUser ? "user" : "assistant", + Content = message + }); + } + } +} diff --git a/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIChatService.cs b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIChatService.cs new file mode 100644 index 00000000..c2b39930 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLIChatService.cs @@ -0,0 +1,131 @@ +using System.Net; +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using Azure.Core; +using Azure.Identity; +using AIShell.Abstraction; + +namespace AIShell.Azure.CLI; + +internal class AzCLIChatService : IDisposable +{ + internal const string Endpoint = "https://azclitools-copilot-apim-temp.azure-api.net/azcli/copilot"; + + private readonly HttpClient _client; + private readonly string[] _scopes; + private readonly List _chatHistory; + private AccessToken? _accessToken; + private string _correlationID; + + internal string CorrelationID => _correlationID; + + internal AzCLIChatService() + { + _client = new HttpClient(); + _scopes = ["https://management.core.windows.net/"]; + _chatHistory = []; + _accessToken = null; + _correlationID = null; + } + + internal List ChatHistory => _chatHistory; + + public void Dispose() + { + _client.Dispose(); + } + + private string NewCorrelationID() + { + _correlationID = Guid.NewGuid().ToString(); + return _correlationID; + } + + private void RefreshToken(CancellationToken cancellationToken) + { + try + { + bool needRefresh = !_accessToken.HasValue; + if (!needRefresh) + { + needRefresh = DateTimeOffset.UtcNow + TimeSpan.FromMinutes(2) > _accessToken.Value.ExpiresOn; + } + + if (needRefresh) + { + _accessToken = new AzureCliCredential() + .GetToken(new TokenRequestContext(_scopes), cancellationToken); + } + } + catch (Exception e) when (e is not OperationCanceledException) + { + throw new RefreshTokenException("Failed to refresh the Azure CLI login token", e); + } + } + + private HttpRequestMessage PrepareForChat(string input) + { + _chatHistory.Add(new ChatMessage() { Role = "user", Content = input }); + + var requestData = new Query { Messages = _chatHistory }; + var json = JsonSerializer.Serialize(requestData, Utils.JsonOptions); + + var content = new StringContent(json, Encoding.UTF8, "application/json"); + var request = new HttpRequestMessage(HttpMethod.Post, Endpoint) { Content = content }; + + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken.Value.Token); + + // These headers are for telemetry. We refresh correlation ID for each query. + request.Headers.Add("CorrelationId", NewCorrelationID()); + request.Headers.Add("ClientType", "Copilot for client tools"); + + return request; + } + + internal async Task GetChatResponseAsync(IStatusContext context, string input, CancellationToken cancellationToken) + { + try + { + context?.Status("Refreshing Token ..."); + RefreshToken(cancellationToken); + + context?.Status("Generating ..."); + HttpRequestMessage request = PrepareForChat(input); + HttpResponseMessage response = await _client.SendAsync(request, cancellationToken); + + if (response.StatusCode is HttpStatusCode.UnprocessableContent) + { + // The AzCLI handler returns status code 422 when the query is out of scope. + // In this case, we don't save the question to the history. + _chatHistory.RemoveAt(_chatHistory.Count - 1); + } + else + { + // Throws if it was not a success response. + response.EnsureSuccessStatusCode(); + } + + context?.Status("Receiving Payload ..."); + var content = await response.Content.ReadAsStreamAsync(cancellationToken); + return JsonSerializer.Deserialize(content, Utils.JsonOptions); + } + catch (Exception exception) + { + // We don't save the question to history when we failed to get a response. + // Check on history count in case the exception is thrown from token refreshing at the very beginning. + if (_chatHistory.Count > 0) + { + _chatHistory.RemoveAt(_chatHistory.Count - 1); + } + + // Re-throw unless the operation was cancelled by user. + if (exception is not OperationCanceledException) + { + throw; + } + } + + return null; + } +} diff --git a/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLISchema.cs b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLISchema.cs new file mode 100644 index 00000000..c97ceed0 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AzCLI/AzCLISchema.cs @@ -0,0 +1,57 @@ +using System.Text.Json.Serialization; + +namespace AIShell.Azure.CLI; + +internal class Query +{ + public List Messages { get; set; } +} + +internal class CommandItem +{ + public string Desc { get; set; } + public string Script { get; set; } +} + +internal class PlaceholderItem +{ + public string Name { get; set; } + public string Desc { get; set; } + public string Type { get; set; } + + [JsonPropertyName("valid_values")] + public List ValidValues { get; set; } +} + +internal class ResponseData +{ + public string Description { get; set; } + public List CommandSet { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List PlaceholderSet { get; set; } +} + +internal class AzCliResponse +{ + public int Status { get; set; } + public string Error { get; set; } + public ResponseData Data { get; set; } +} + +internal class ArgumentPlaceholder +{ + internal ArgumentPlaceholder(string query, ResponseData data) + { + ArgumentException.ThrowIfNullOrEmpty(query); + ArgumentNullException.ThrowIfNull(data); + + Query = query; + ResponseData = data; + DataRetriever = new(data); + } + + public string Query { get; set; } + public ResponseData ResponseData { get; set; } + public DataRetriever DataRetriever { get; } +} diff --git a/shell/agents/AIShell.Azure.Agent/AzCLI/Command.cs b/shell/agents/AIShell.Azure.Agent/AzCLI/Command.cs new file mode 100644 index 00000000..12bc93f1 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AzCLI/Command.cs @@ -0,0 +1,278 @@ +using System.CommandLine; +using System.Text; +using System.Text.Json; +using AIShell.Abstraction; + +namespace AIShell.Azure.CLI; + +internal sealed class ReplaceCommand : CommandBase +{ + private readonly AzCLIAgent _agent; + private readonly Dictionary _values; + private readonly Dictionary _pseudoValues; + private readonly HashSet _productNames; + private readonly HashSet _environmentNames; + + public ReplaceCommand(AzCLIAgent agent) + : base("replace", "Replace argument placeholders in the generated scripts with the real value.") + { + _agent = agent; + _values = []; + _pseudoValues = []; + _productNames = []; + _environmentNames = []; + + this.SetHandler(ReplaceAction); + } + + private static string SyntaxHighlightAzCommand(string command, string parameter, string placeholder) + { + const string vtItalic = "\x1b[3m"; + const string vtCommand = "\x1b[93m"; + const string vtParameter = "\x1b[90m"; + const string vtVariable = "\x1b[92m"; + const string vtFgDefault = "\x1b[39m"; + const string vtReset = "\x1b[0m"; + + StringBuilder cStr = new(capacity: command.Length + parameter.Length + placeholder.Length + 50); + cStr.Append(vtItalic) + .Append(vtCommand).Append("az").Append(vtFgDefault).Append(command.AsSpan(2)).Append(' ') + .Append(vtParameter).Append(parameter).Append(vtFgDefault).Append(' ') + .Append(vtVariable).Append(placeholder).Append(vtFgDefault) + .Append(vtReset); + + return cStr.ToString(); + } + + private void ReplaceAction() + { + _values.Clear(); + _pseudoValues.Clear(); + _productNames.Clear(); + _environmentNames.Clear(); + + IHost host = Shell.Host; + ArgumentPlaceholder ap = _agent.ArgPlaceholder; + UserValueStore uvs = _agent.ValueStore; + + if (ap is null) + { + host.WriteErrorLine("No argument placeholder to replace."); + return; + } + + DataRetriever dataRetriever = ap.DataRetriever; + List items = ap.ResponseData.PlaceholderSet; + string subText = items.Count > 1 + ? $"all {items.Count} argument placeholders" + : "the argument placeholder"; + host.WriteLine($"\nWe'll provide assistance in replacing {subText} and regenerating the result. You can press 'Enter' to skip to the next parameter or press 'Ctrl+c' to exit the assistance.\n"); + host.RenderDivider("Input Values"); + host.WriteLine(); + + try + { + for (int i = 0; i < items.Count; i++) + { + var item = items[i]; + var (command, parameter) = dataRetriever.GetMappedCommand(item.Name); + + string desc = item.Desc.TrimEnd('.'); + string coloredCmd = parameter is null ? null : SyntaxHighlightAzCommand(command, parameter, item.Name); + string cmdPart = coloredCmd is null ? null : $" [{coloredCmd}]"; + + host.WriteLine(item.Type is "string" + ? $"{i+1}. {desc}{cmdPart}" + : $"{i+1}. {desc}{cmdPart}. Value type: {item.Type}"); + + // Get the task for creating the 'ArgumentInfo' object and show a spinner + // if we have to wait for the task to complete. + Task argInfoTask = dataRetriever.GetArgInfo(item.Name); + ArgumentInfo argInfo = argInfoTask.IsCompleted + ? argInfoTask.Result + : host.RunWithSpinnerAsync( + () => WaitForArgInfoAsync(argInfoTask), + status: $"Requesting data for '{item.Name}' ...", + SpinnerKind.Processing).GetAwaiter().GetResult(); + + argInfo ??= new ArgumentInfo(item.Name, item.Desc, Enum.Parse(item.Type)); + + // Write out restriction for this argument if there is any. + if (!string.IsNullOrEmpty(argInfo.Restriction)) + { + host.WriteLine(argInfo.Restriction); + } + + ArgumentInfoWithNamingRule nameArgInfo = null; + if (argInfo is ArgumentInfoWithNamingRule v) + { + nameArgInfo = v; + SuggestForResourceName(nameArgInfo.NamingRule, nameArgInfo.Suggestions); + } + + // Prompt for argument without printing captions again. + string value = host.PromptForArgument(argInfo, printCaption: false); + if (!string.IsNullOrEmpty(value)) + { + string pseudoValue = uvs.SaveUserInputValue(value); + _values.Add(item.Name, value); + _pseudoValues.Add(item.Name, pseudoValue); + + if (nameArgInfo is not null && nameArgInfo.NamingRule.TryMatchName(value, out string prodName, out string envName)) + { + _productNames.Add(prodName.ToLower()); + _environmentNames.Add(envName.ToLower()); + } + } + + // Write an extra new line. + host.WriteLine(); + } + } + catch (OperationCanceledException) + { + bool proceed = false; + if (_values.Count > 0) + { + host.WriteLine(); + proceed = host.PromptForConfirmationAsync( + "Would you like to regenerate with the provided values so far?", + defaultValue: false, + CancellationToken.None).GetAwaiter().GetResult(); + host.WriteLine(); + } + + if (!proceed) + { + host.WriteLine(); + return; + } + } + + if (_values.Count > 0) + { + host.RenderDivider("Summary"); + host.WriteLine("\nThe following placeholders will be replace:"); + host.RenderList(_values); + + host.RenderDivider("Regenerate"); + host.MarkupLine($"\nQuery: [teal]{ap.Query}[/]"); + + try + { + string answer = host.RunWithSpinnerAsync(RegenerateAsync).GetAwaiter().GetResult(); + host.RenderFullResponse(answer); + } + catch (OperationCanceledException) + { + // User cancelled the operation. + } + } + else + { + host.WriteLine("No value was specified for any of the argument placeholders."); + } + } + + private void SuggestForResourceName(NamingRule rule, IList suggestions) + { + if (_productNames.Count is 0) + { + return; + } + + foreach (string prodName in _productNames) + { + if (_environmentNames.Count is 0) + { + suggestions.Add($"{prodName}-{rule.Abbreviation}"); + continue; + } + + foreach (string envName in _environmentNames) + { + suggestions.Add($"{prodName}-{rule.Abbreviation}-{envName}"); + } + } + } + + private async Task WaitForArgInfoAsync(Task argInfoTask) + { + var token = Shell.CancellationToken; + var cts = CancellationTokenSource.CreateLinkedTokenSource(token); + + // Do not let the user wait for more than 2 seconds. + var delayTask = Task.Delay(2000, cts.Token); + var completedTask = await Task.WhenAny(argInfoTask, delayTask); + + if (completedTask == delayTask) + { + if (delayTask.IsCanceled) + { + // User cancelled the operation. + throw new OperationCanceledException(token); + } + + // Timed out. Last try to see if it finished. Otherwise, return null. + return argInfoTask.IsCompletedSuccessfully ? argInfoTask.Result : null; + } + + // Finished successfully, so we cancel the delay task and return the result. + cts.Cancel(); + return argInfoTask.Result; + } + + /// + /// We use the pseudo values to regenerate the response data, so that real values will never go off the user's box. + /// + /// + private async Task RegenerateAsync() + { + ArgumentPlaceholder ap = _agent.ArgPlaceholder; + StringBuilder prompt = new(capacity: ap.Query.Length + _pseudoValues.Count * 15); + prompt.Append("Regenerate for the last query using the following values specified for the argument placeholders.\n\n"); + + // We use the pseudo values when building the new prompt, because the new prompt + // will be added to history, and we don't want real values to go off the box. + foreach (var entry in _pseudoValues) + { + prompt.Append($"{entry.Key}: {entry.Value}\n"); + } + + // We are doing the replacement locally, but want to fake the regeneration. + await Task.Delay(2000, Shell.CancellationToken); + + ResponseData data = ap.ResponseData; + foreach (CommandItem command in data.CommandSet) + { + foreach (var entry in _pseudoValues) + { + command.Script = command.Script.Replace(entry.Key, entry.Value, StringComparison.OrdinalIgnoreCase); + } + } + + List placeholders = data.PlaceholderSet; + if (placeholders.Count == _pseudoValues.Count) + { + data.PlaceholderSet = null; + } + else if (placeholders.Count > _pseudoValues.Count) + { + List newList = new(placeholders.Count - _pseudoValues.Count); + foreach (PlaceholderItem item in placeholders) + { + if (!_pseudoValues.ContainsKey(item.Name)) + { + newList.Add(item); + } + } + + data.PlaceholderSet = newList; + } + + _agent.AddMessageToHistory(prompt.ToString(), fromUser: true); + _agent.AddMessageToHistory(JsonSerializer.Serialize(data, Utils.JsonOptions), fromUser: false); + + return _agent.GenerateAnswer(ap.Query, data); + } +} diff --git a/shell/agents/AIShell.Azure.Agent/AzCLI/DataRetriever.cs b/shell/agents/AIShell.Azure.Agent/AzCLI/DataRetriever.cs new file mode 100644 index 00000000..f3aaa852 --- /dev/null +++ b/shell/agents/AIShell.Azure.Agent/AzCLI/DataRetriever.cs @@ -0,0 +1,781 @@ +using System.Collections.Concurrent; +using System.ComponentModel; +using System.Diagnostics; +using System.Text.Json; +using System.Text.RegularExpressions; +using AIShell.Abstraction; + +namespace AIShell.Azure.CLI; + +internal class DataRetriever : IDisposable +{ + private static readonly Dictionary s_azNamingRules; + private static readonly ConcurrentDictionary s_azStaticDataCache; + + private readonly string _staticDataRoot; + private readonly Task _rootTask; + private readonly SemaphoreSlim _semaphore; + private readonly List _placeholders; + private readonly Dictionary _placeholderMap; + + private bool _stop; + + static DataRetriever() + { + List rules = [ + new("API Management Service", + "apim", + "The name only allows alphanumeric characters and hyphens, and the first character must be a letter. Length: 1 to 50 chars.", + "az apim create --name", + "New-AzApiManagement -Name"), + + new("Function App", + "func", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 2 to 60 chars.", + "az functionapp create --name", + "New-AzFunctionApp -Name"), + + new("App Service Plan", + "asp", + "The name only allows alphanumeric characters and hyphens. Length: 1 to 60 chars.", + "az appservice plan create --name", + "New-AzAppServicePlan -Name"), + + new("Web App", + "web", + "The name only allows alphanumeric characters and hyphens. Length: 2 to 43 chars.", + "az webapp create --name", + "New-AzWebApp -Name"), + + new("Application Gateway", + "agw", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 - 80 chars.", + "az network application-gateway create --name", + "New-AzApplicationGateway -Name"), + + new("Application Insights", + "ai", + "The name only allows alphanumeric characters, underscores, periods, hyphens and parenthesis, and cannot end in a period. Length: 1 to 255 chars.", + "az monitor app-insights component create --app", + "New-AzApplicationInsights -Name"), + + new("Application Security Group", + "asg", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network asg create --name", + "New-AzApplicationSecurityGroup -Name"), + + new("Automation Account", + "aa", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 6 to 50 chars.", + "az automation account create --name", + "New-AzAutomationAccount -Name"), + + new("Availability Set", + "as", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az vm availability-set create --name", + "New-AzAvailabilitySet -Name"), + + new("Redis Cache", + "redis", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Consecutive hyphens are not allowed. Length: 1 to 63 chars.", + "az redis create --name", + "New-AzRedisCache -Name"), + + new("Cognitive Service", + "cogs", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 2 to 64 chars.", + "az cognitiveservices account create --name", + "New-AzCognitiveServicesAccount -Name"), + + new("Cosmos DB", + "cosmos", + "The name only allows lowercase letters, numbers, and hyphens, and cannot start or end with a hyphen. Length: 3 to 44 chars.", + "az cosmosdb create --name", + "New-AzCosmosDBAccount -Name"), + + new("Event Hubs Namespace", + "eh", + "The name only allows alphanumeric characters and hyphens. It must start with a letter and end with a letter or number. Length: 6 to 50 chars.", + "az eventhubs namespace create --name", + "New-AzEventHubNamespace -Name"), + + new("Event Hubs", + abbreviation: null, + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start and end with a letter or number. Length: 1 to 256 chars.", + "az eventhubs eventhub create --name", + "New-AzEventHub -Name"), + + new("Key Vault", + "kv", + "The name only allows alphanumeric characters and hyphens. It must start with a letter and end with a letter or number. Consecutive hyphens are not allowed. Length: 3 to 24 chars.", + "az keyvault create --name", + "New-AzKeyVault -Name"), + + new("Load Balancer", + "lb", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network lb create --name", + "New-AzLoadBalancer -Name"), + + new("Log Analytics workspace", + "la", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 4 to 63 chars.", + "az monitor log-analytics workspace create --name", + "New-AzOperationalInsightsWorkspace -Name"), + + new("Logic App", + "lapp", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 2 to 64 chars.", + "az logic workflow create --name", + "New-AzLogicApp -Name"), + + new("Machine Learning workspace", + "mlw", + "The name only allows alphanumeric characters, underscores, and hyphens. It must start with a letter or number. Length: 3 to 33 chars.", + "az ml workspace create --name", + "New-AzMLWorkspace -Name"), + + new("Network Interface", + "nic", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 2 to 64 chars.", + "az network nic create --name", + "New-AzNetworkInterface -Name"), + + new("Network Security Group", + "nsg", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 2 to 64 chars.", + "az network nsg create --name", + "New-AzNetworkSecurityGroup -Name"), + + new("Notification Hub Namespace", + "nh", + "The name only allows alphanumeric characters and hyphens. It must start with a letter and end with a letter or number. Length: 6 to 50 chars.", + "az notification-hub namespace create --name", + "New-AzNotificationHubsNamespace -Namespace"), + + new("Notification Hub", + abbreviation: null, + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start and end with a letter or number. Length: 1 to 260 chars.", + "az notification-hub create --name", + "New-AzNotificationHub -Name"), + + new("Public IP address", + "pip", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network public-ip create --name", + "New-AzPublicIpAddress -Name"), + + new("Resource Group", + "rg", + "Resource group names can only include alphanumeric, underscore, parentheses, hyphen, period (except at end), and Unicode characters that match the allowed characters. Length: 1 to 90 chars.", + "az group create --name", + "New-AzResourceGroup -Name"), + + new("Route table", + "rt", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start and end with a letter or number. Length: 1 to 80 chars.", + "az network route-table create --name", + "New-AzRouteTable -Name"), + + new("Search Service", + "srch", + "Service name must only contain lowercase letters, digits or dashes, cannot use dash as the first two or last one characters, and cannot contain consecutive dashes. Length: 2 to 60 chars.", + "az search service create --name", + "New-AzSearchService -Name"), + + new("Service Bus Namespace", + "sb", + "The name only allows alphanumeric characters and hyphens. It must start with a letter and end with a letter or number. Length: 6 to 50 chars.", + "az servicebus namespace create --name", + "New-AzServiceBusNamespace -Name"), + + new("Service Bus queue", + abbreviation: null, + "The name only allows alphanumeric characters and hyphens. It must start with a letter and end with a letter or number. Length: 6 to 50 chars.", + "az servicebus queue create --name", + "New-AzServiceBusQueue -Name"), + + new("Azure SQL Managed Instance", + "sqlmi", + "The name can only contain lowercase letters, numbers and hyphens. It cannot start or end with a hyphen, nor can it have two consecutive hyphens in the third and fourth places of the name. Length: 1 to 63 chars.", + "az sql mi create --name", + "New-AzSqlInstance -Name"), + + new("SQL Server", + "sqldb", + "The name can only contain lowercase letters, numbers and hyphens. It cannot start or end with a hyphen, nor can it have two consecutive hyphens in the third and fourth places of the name. Length: 1 to 63 chars.", + "az sql server create --name", + "New-AzSqlServer -ServerName"), + + new("Storage Container", + abbreviation: null, + "The name can only contain lowercase letters, numbers and hyphens. It must start with a letter or a number, and each hyphen must be preceded and followed by a non-hyphen character. Length: 3 to 63 chars.", + "az storage container create --name", + "New-AzStorageContainer -Name"), + + new("Storage Queue", + abbreviation: null, + "The name can only contain lowercase letters, numbers and hyphens. It must start with a letter or a number, and each hyphen must be preceded and followed by a non-hyphen character. Length: 3 to 63 chars.", + "az storage queue create --name", + "New-AzStorageQueue -Name"), + + new("Storage Table", + abbreviation: null, + "The name can only contain letters and numbers, and must start with a letter. Length: 3 to 63 chars.", + "az storage table create --name", + "New-AzStorageTable -Name"), + + new("Storage File Share", + abbreviation: null, + "The name can only contain lowercase letters, numbers and hyphens. It must start and end with a letter or number, and cannot contain two consecutive hyphens. Length: 3 to 63 chars.", + "az storage share create --name", + "New-AzStorageShare -Name"), + + new("Container Registry", + "cr", + "The name only allows alphanumeric characters. Length: 5 to 50 chars.", + "cr[][]", + ["crnavigatorprod001", "crhadoopdev001"], + "az acr create --name", + "New-AzContainerRegistry -Name"), + + new("Storage Account", + "st", + "The name can only contain lowercase letters and numbers. Length: 3 to 24 chars.", + "st[][]", + ["stsalesappdataqa", "sthadoopoutputtest"], + "az storage account create --name", + "New-AzStorageAccount -Name"), + + new("Traffic Manager profile", + "tm", + "The name only allows alphanumeric characters and hyphens, and cannot start or end with a hyphen. Length: 1 to 63 chars.", + "az network traffic-manager profile create --name", + "New-AzTrafficManagerProfile -Name"), + + new("Virtual Machine", + "vm", + @"The name cannot contain special characters \/""[]:|<>+=;,?*@&#%, whitespace, or begin with '_' or end with '.' or '-'. Length: 1 to 15 chars for Windows; 1 to 64 chars for Linux.", + "az vm create --name", + "New-AzVM -Name"), + + new("Virtual Network Gateway", + "vgw", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network vnet-gateway create --name", + "New-AzVirtualNetworkGateway -Name"), + + new("Local Network Gateway", + "lgw", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network local-gateway create --name", + "New-AzLocalNetworkGateway -Name"), + + new("Virtual Network", + "vnet", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network vnet create --name", + "New-AzVirtualNetwork -Name"), + + new("Subnet", + "snet", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network vnet subnet create --name", + "Add-AzVirtualNetworkSubnetConfig -Name"), + + new("VPN Connection", + "vcn", + "The name only allows alphanumeric characters, underscores, periods, and hyphens. It must start with a letter or number, and end with a letter, number or underscore. Length: 1 to 80 chars.", + "az network vpn-connection create --name", + "New-AzVpnConnection -Name"), + ]; + + s_azNamingRules = new(capacity: rules.Count * 2, StringComparer.OrdinalIgnoreCase); + foreach (var rule in rules) + { + s_azNamingRules.Add(rule.AzCLICommand, rule); + s_azNamingRules.Add(rule.AzPSCommand, rule); + } + + s_azStaticDataCache = new(StringComparer.OrdinalIgnoreCase); + } + + internal DataRetriever(ResponseData data) + { + _stop = false; + _semaphore = new SemaphoreSlim(3, 3); + _staticDataRoot = @"E:\yard\tmp\az-cli-out\az"; + _placeholders = new(capacity: data.PlaceholderSet.Count); + _placeholderMap = new(capacity: data.PlaceholderSet.Count); + + PairPlaceholders(data); + _rootTask = Task.Run(StartProcessing); + } + + private void PairPlaceholders(ResponseData data) + { + var cmds = new Dictionary(data.CommandSet.Count); + + foreach (var item in data.PlaceholderSet) + { + string command = null, parameter = null; + + foreach (var cmd in data.CommandSet) + { + string script = cmd.Script; + + // Handle AzCLI commands. + if (script.StartsWith("az ", StringComparison.OrdinalIgnoreCase)) + { + if (!cmds.TryGetValue(script, out command)) + { + int firstParamIndex = script.IndexOf("--"); + command = script.AsSpan(0, firstParamIndex).Trim().ToString(); + cmds.Add(script, command); + } + + int argIndex = script.IndexOf(item.Name, StringComparison.OrdinalIgnoreCase); + if (argIndex is -1) + { + continue; + } + + int paramIndex = script.LastIndexOf("--", argIndex); + parameter = script.AsSpan(paramIndex, argIndex - paramIndex).Trim().ToString(); + + break; + } + + // It's a non-AzCLI command, such as "ssh". + if (script.Contains(item.Name, StringComparison.OrdinalIgnoreCase)) + { + // Leave the parameter to be null for non-AzCLI commands, as there is + // no reliable way to parse an arbitrary command + command = script; + parameter = null; + + break; + } + } + + ArgumentPair pair = new(item, command, parameter); + _placeholders.Add(pair); + _placeholderMap.Add(item.Name, pair); + } + } + + private void StartProcessing() + { + foreach (var pair in _placeholders) + { + if (_stop) { break; } + + _semaphore.Wait(); + + if (pair.ArgumentInfo is null) + { + lock (pair) + { + if (pair.ArgumentInfo is null) + { + pair.ArgumentInfo = Task.Factory.StartNew(ProcessOne, pair); + continue; + } + } + } + + _semaphore.Release(); + } + + ArgumentInfo ProcessOne(object pair) + { + try + { + return CreateArgInfo((ArgumentPair)pair); + } + finally + { + _semaphore.Release(); + } + } + } + + private ArgumentInfo CreateArgInfo(ArgumentPair pair) + { + var item = pair.Placeholder; + var dataType = Enum.Parse(item.Type, ignoreCase: true); + + if (item.ValidValues?.Count > 0) + { + return new ArgumentInfo(item.Name, item.Desc, restriction: null, dataType, item.ValidValues); + } + + // Handle non-AzCLI command. + if (pair.Parameter is null) + { + return new ArgumentInfo(item.Name, item.Desc, dataType); + } + + string cmdAndParam = $"{pair.Command} {pair.Parameter}"; + if (s_azNamingRules.TryGetValue(cmdAndParam, out NamingRule rule)) + { + string restriction = rule.PatternText is null + ? rule.GeneralRule + : $""" + - {rule.GeneralRule} + - Recommended pattern: {rule.PatternText}, e.g. {string.Join(", ", rule.Example)}. + """; + return new ArgumentInfoWithNamingRule(item.Name, item.Desc, restriction, rule); + } + + if (string.Equals(pair.Parameter, "--name", StringComparison.OrdinalIgnoreCase) + && pair.Command.EndsWith(" create", StringComparison.OrdinalIgnoreCase)) + { + // Placeholder is for the name of a new resource to be created, but not in our cache. + return new ArgumentInfo(item.Name, item.Desc, dataType); + } + + if (_stop) { return null; } + + List suggestions = GetArgValues(pair, out Option option); + // If the option's description is less than the placeholder's description in length, then it's + // unlikely to provide more information than the latter. In that case, we don't use it. + string optionDesc = option?.Description?.Length > item.Desc.Length ? option.Description : null; + return new ArgumentInfo(item.Name, item.Desc, optionDesc, dataType, suggestions); + } + + private List GetArgValues(ArgumentPair pair, out Option option) + { + // First, try to get static argument values if they exist. + string command = pair.Command; + if (!s_azStaticDataCache.TryGetValue(command, out Command commandData)) + { + string[] cmdElements = command.Split(' ', StringSplitOptions.RemoveEmptyEntries); + string dirPath = _staticDataRoot; + for (int i = 1; i < cmdElements.Length - 1; i++) + { + dirPath = Path.Combine(dirPath, cmdElements[i]); + } + + string filePath = Path.Combine(dirPath, cmdElements[^1] + ".json"); + commandData = File.Exists(filePath) + ? JsonSerializer.Deserialize(File.OpenRead(filePath)) + : null; + s_azStaticDataCache.TryAdd(command, commandData); + } + + option = commandData?.FindOption(pair.Parameter); + List staticValues = option?.Arguments; + if (staticValues?.Count > 0) + { + return staticValues; + } + + if (_stop) { return null; } + + // Then, try to get dynamic argument values using AzCLI tab completion. + string commandLine = $"{pair.Command} {pair.Parameter} "; + string tempFile = Path.GetTempFileName(); + + try + { + using var process = new Process() + { + StartInfo = new ProcessStartInfo() + { + FileName = @"C:\Program Files\Microsoft SDKs\Azure\CLI2\python.exe", + Arguments = "-Im azure.cli", + UseShellExecute = false, + RedirectStandardOutput = true, + RedirectStandardError = true, + } + }; + + var env = process.StartInfo.Environment; + env.Add("ARGCOMPLETE_USE_TEMPFILES", "1"); + env.Add("_ARGCOMPLETE_STDOUT_FILENAME", tempFile); + env.Add("COMP_LINE", commandLine); + env.Add("COMP_POINT", (commandLine.Length + 1).ToString()); + env.Add("_ARGCOMPLETE", "1"); + env.Add("_ARGCOMPLETE_SUPPRESS_SPACE", "0"); + env.Add("_ARGCOMPLETE_IFS", "\n"); + env.Add("_ARGCOMPLETE_SHELL", "powershell"); + + process.Start(); + process.WaitForExit(); + + string line; + using FileStream stream = File.OpenRead(tempFile); + if (stream.Length is 0) + { + // No allowed values for the option. + return null; + } + + using StreamReader reader = new(stream); + List output = []; + + while ((line = reader.ReadLine()) is not null) + { + if (line.StartsWith('-')) + { + // Argument completion generates incorrect results -- options are written into the file instead of argument allowed values. + return null; + } + + string value = line.Trim(); + if (value != string.Empty) + { + output.Add(value); + } + } + + return output.Count > 0 ? output : null; + } + catch (Win32Exception e) + { + throw new ApplicationException($"Failed to get allowed values for 'az {commandLine}': {e.Message}", e); + } + finally + { + if (File.Exists(tempFile)) + { + File.Delete(tempFile); + } + } + } + + internal (string command, string parameter) GetMappedCommand(string placeholderName) + { + if (_placeholderMap.TryGetValue(placeholderName, out ArgumentPair pair)) + { + return (pair.Command, pair.Parameter); + } + + throw new ArgumentException($"Unknown placeholder name: '{placeholderName}'", nameof(placeholderName)); + } + + internal Task GetArgInfo(string placeholderName) + { + if (_placeholderMap.TryGetValue(placeholderName, out ArgumentPair pair)) + { + if (pair.ArgumentInfo is null) + { + lock (pair) + { + pair.ArgumentInfo ??= Task.Run(() => CreateArgInfo(pair)); + } + } + + return pair.ArgumentInfo; + } + + throw new ArgumentException($"Unknown placeholder name: '{placeholderName}'", nameof(placeholderName)); + } + + public void Dispose() + { + _stop = true; + _rootTask.Wait(); + _semaphore.Dispose(); + } +} + +internal class ArgumentPair +{ + internal PlaceholderItem Placeholder { get; } + internal string Command { get; } + internal string Parameter { get; } + internal Task ArgumentInfo { set; get; } + + internal ArgumentPair(PlaceholderItem placeholder, string command, string parameter) + { + Placeholder = placeholder; + Command = command; + Parameter = parameter; + ArgumentInfo = null; + } +} + +internal class ArgumentInfoWithNamingRule : ArgumentInfo +{ + internal ArgumentInfoWithNamingRule(string name, string description, string restriction, NamingRule rule) + : base(name, description, restriction, DataType.@string, suggestions: []) + { + ArgumentNullException.ThrowIfNull(rule); + NamingRule = rule; + } + + internal NamingRule NamingRule { get; } +} + +internal class NamingRule +{ + private static readonly string[] s_products = ["salesapp", "bookingweb", "navigator", "hadoop", "sharepoint"]; + private static readonly string[] s_envs = ["prod", "dev", "qa", "stage", "test"]; + + internal string ResourceName { get; } + internal string Abbreviation { get; } + internal string GeneralRule { get; } + internal string PatternText { get; } + internal Regex PatternRegex { get; } + internal string[] Example { get; } + + internal string AzCLICommand { get; } + internal string AzPSCommand { get; } + + internal NamingRule( + string resourceName, + string abbreviation, + string generalRule, + string azCLICommand, + string azPSCommand) + { + ArgumentException.ThrowIfNullOrEmpty(resourceName); + ArgumentException.ThrowIfNullOrEmpty(generalRule); + ArgumentException.ThrowIfNullOrEmpty(azCLICommand); + ArgumentException.ThrowIfNullOrEmpty(azPSCommand); + + ResourceName = resourceName; + Abbreviation = abbreviation; + GeneralRule = generalRule; + AzCLICommand = azCLICommand; + AzPSCommand = azPSCommand; + + if (abbreviation is not null) + { + PatternText = $"-{abbreviation}[-][-]"; + PatternRegex = new Regex($"^(?[a-zA-Z0-9]+)-{abbreviation}(?:-(?[a-zA-Z0-9]+))?(?:-[a-zA-Z0-9]+)?$", RegexOptions.Compiled); + + string product = s_products[Random.Shared.Next(0, s_products.Length)]; + int envIndex = Random.Shared.Next(0, s_envs.Length); + Example = [$"{product}-{abbreviation}-{s_envs[envIndex]}", $"{product}-{abbreviation}-{s_envs[(envIndex + 1) % s_envs.Length]}"]; + } + } + + internal NamingRule( + string resourceName, + string abbreviation, + string generalRule, + string patternText, + string[] example, + string azCLICommand, + string azPSCommand) + { + ArgumentException.ThrowIfNullOrEmpty(resourceName); + ArgumentException.ThrowIfNullOrEmpty(generalRule); + ArgumentException.ThrowIfNullOrEmpty(azCLICommand); + ArgumentException.ThrowIfNullOrEmpty(azPSCommand); + + ResourceName = resourceName; + Abbreviation = abbreviation; + GeneralRule = generalRule; + PatternText = patternText; + PatternRegex = null; + Example = example; + + AzCLICommand = azCLICommand; + AzPSCommand = azPSCommand; + } + + internal bool TryMatchName(string name, out string prodName, out string envName) + { + prodName = envName = null; + if (PatternRegex is null) + { + return false; + } + + Match match = PatternRegex.Match(name); + if (match.Success) + { + prodName = match.Groups["prod"].Value; + envName = match.Groups["env"].Value; + return true; + } + + return false; + } +} + +public class Option +{ + public string Name { get; } + public string[] Alias { get; } + public string[] Short { get; } + public string Attribute { get; } + public string Description { get; set; } + public List Arguments { get; set; } + + public Option(string name, string description, string[] alias, string[] @short, string attribute, List arguments) + { + ArgumentException.ThrowIfNullOrEmpty(name); + ArgumentException.ThrowIfNullOrEmpty(description); + + Name = name; + Alias = alias; + Short = @short; + Attribute = attribute; + Description = description; + Arguments = arguments; + } +} + +public sealed class Command +{ + public List