From a4b0e510f5ae3a06bdc536ba68f4edd34f701996 Mon Sep 17 00:00:00 2001 From: Ashlee Nanze Date: Wed, 21 Jun 2023 23:57:53 -0700 Subject: [PATCH] Updates --- src/code/OpenAI.cs | 63 ++++++++++++++++++++++++---------------- src/code/Policy.cs | 35 ++-------------------- src/code/Screenbuffer.cs | 2 +- 3 files changed, 42 insertions(+), 58 deletions(-) diff --git a/src/code/OpenAI.cs b/src/code/OpenAI.cs index c3eaec1c..1b8e11a7 100644 --- a/src/code/OpenAI.cs +++ b/src/code/OpenAI.cs @@ -7,15 +7,16 @@ using System.Threading.Tasks; using Azure.AI.OpenAI; using Azure; +using Azure.Core; namespace Microsoft.PowerShell.Copilot { internal class OpenAI { + private const string API_ENV_VAR = "AZURE_OPENAI_API_KEY"; internal const string ENDPOINT_ENV_VAR = "AZURE_OPENAI_ENDPOINT"; internal const string SYSTEM_PROMPT_ENV_VAR = "AZURE_OPENAI_SYSTEM_PROMPT"; - private const string API_ENV_VAR = "AZURE_OPENAI_API_KEY"; private static readonly string[] SPINNER = new string[8] {"🌑", "🌒", "🌓", "🌔", "🌕", "🌖", "🌗", "🌘"}; private static List _promptHistory = new(); @@ -36,18 +37,30 @@ public OpenAI() } OpenAIClientOptions options = new OpenAIClientOptions(); + options.Retry.MaxRetries = 0; + string apiKey = Environment.GetEnvironmentVariable(API_ENV_VAR); if (apiKey is null) { throw(new Exception($"{API_ENV_VAR} environment variable not set")); } - //adds policy - AzureKeyCredentialPolicy policy = new AzureKeyCredentialPolicy(new AzureKeyCredential(apiKey), "Ocp-Apim-Subscription-Key"); - options.AddPolicy(policy, Azure.Core.HttpPipelinePosition.PerRetry); - //creates client - client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey), options); + if (endpoint.EndsWith(".azure-api.net", StringComparison.Ordinal)) + { + AzureKeyCredentialPolicy policy = new AzureKeyCredentialPolicy(new AzureKeyCredential(apiKey), "Ocp-Apim-Subscription-Key"); + options.AddPolicy(policy, Azure.Core.HttpPipelinePosition.PerRetry); + + client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential("placeholder"), options); + } + else if (endpoint.EndsWith(".openai.azure.com", StringComparison.Ordinal)) + { + client = new OpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey)); + } + else + { + throw new Exception($"The specified endpoint '{endpoint}' is not a valid Azure OpenAI service endpoint."); + } } internal string LastCodeSnippet() @@ -166,35 +179,35 @@ internal string GetCompletion(string prompt, bool debug, CancellationToken cance Console.WriteLine($"{PSStyle.Instance.Foreground.BrightMagenta}DEBUG: OpenAI URL: {endpoint}"); } - string modelName = ""; + string openai_model = ""; switch (EnterCopilot._model) - { - case Model.GPT4: - modelName = "gpt4"; - break; - case Model.GPT4_32K: - modelName = "gpt4-32k"; - break; - default: - modelName ="gpt-35-turbo"; - break; - } + { + case Model.GPT35_Turbo: + openai_model = "gpt-35-turbo"; + break; + case Model.GPT4_32K: + openai_model = "gpt4-32k"; + break; + default: + openai_model = "gpt4"; + break; + } - Response response = client.GetChatCompletions( - deploymentOrModelName: modelName, + deploymentOrModelName: openai_model, requestBody); - ChatCompletions chatCompletions = response.Value; var output = "\n"; - foreach (ChatChoice choice in chatCompletions.Choices) - { - output += choice.Message.Content; - } + output += chatCompletions.Choices[0].Message.Content; return output; + + } + catch (Azure.RequestFailedException e) + { + return $"{PSStyle.Instance.Foreground.BrightRed}HTTP EXCEPTION: {e.Message}"; } catch (OperationCanceledException) { diff --git a/src/code/Policy.cs b/src/code/Policy.cs index 3636d04e..f85f3289 100644 --- a/src/code/Policy.cs +++ b/src/code/Policy.cs @@ -7,54 +7,25 @@ internal class AzureKeyCredentialPolicy : HttpPipelineSynchronousPolicy private readonly string _name; private readonly AzureKeyCredential _credential; - private readonly string _prefix; /// /// Initializes a new instance of the class. /// /// The used to authenticate requests. /// The name of the key header used for the credential. - /// The prefix to apply before the credential key. For example, a prefix of "SharedAccessKey" would result in - /// a value of "SharedAccessKey {credential.Key}" being stamped on the request header with header key of . - public AzureKeyCredentialPolicy(AzureKeyCredential credential, string name, object prefix = null) + + public AzureKeyCredentialPolicy(AzureKeyCredential credential, string name) { _credential = credential; _name = name; - if(_prefix != null) - { - _prefix = (string) prefix; - } } /// public override void OnSendingRequest(Azure.Core.HttpMessage message) { base.OnSendingRequest(message); - message.Request.Headers.SetValue(_name, _prefix != null ? $"{_prefix} {_credential.Key}" : _credential.Key); - } - public override void OnReceivedResponse(Azure.Core.HttpMessage message) - { - base.OnReceivedResponse(message); - if (message.HasResponse == true && message.Response.Status == 429) - { - throw(new RateLimitException(message.Response.Content.ToString())); - } - + message.Request.Headers.SetValue(_name, _credential.Key); } } -internal class RateLimitException : Exception -{ - public RateLimitException(string message) : base(message) { } - - public override string ToString() - { - return Message; - } - - public override string StackTrace - { - get{return "";} - } -} diff --git a/src/code/Screenbuffer.cs b/src/code/Screenbuffer.cs index 6b5893d6..b7c0e2a7 100644 --- a/src/code/Screenbuffer.cs +++ b/src/code/Screenbuffer.cs @@ -47,7 +47,7 @@ internal static void RedrawScreen() else { WriteLineConsole($"{RESET}{LOGO}"); - string openai_url = "GPT-4"; + string openai_url = Environment.GetEnvironmentVariable(OpenAI.ENDPOINT_ENV_VAR); if (openai_url is null) { WriteLineConsole($"{PSStyle.Instance.Foreground.Yellow}Using {EnterCopilot._model}");