From af53ce223041a4e0a10736aadb62820268027fe9 Mon Sep 17 00:00:00 2001 From: Rainer Date: Tue, 23 Apr 2024 08:39:53 +0200 Subject: [PATCH] Add assistant .NET sample --- labs/.vscode/launch.json | 11 + .../035-assistants-dotnet.csproj | 18 ++ labs/035-assistants-dotnet/Functions.cs | 250 ++++++++++++++++++ labs/035-assistants-dotnet/JsonHelpers.cs | 14 + .../035-assistants-dotnet/OpenAIExtensions.cs | 128 +++++++++ labs/035-assistants-dotnet/Program.cs | 87 ++++++ labs/Directory.Packages.props | 18 +- labs/HandsOnLabs.sln | 6 + readme.md | 4 +- 9 files changed, 527 insertions(+), 9 deletions(-) create mode 100644 labs/035-assistants-dotnet/035-assistants-dotnet.csproj create mode 100644 labs/035-assistants-dotnet/Functions.cs create mode 100644 labs/035-assistants-dotnet/JsonHelpers.cs create mode 100644 labs/035-assistants-dotnet/OpenAIExtensions.cs create mode 100644 labs/035-assistants-dotnet/Program.cs diff --git a/labs/.vscode/launch.json b/labs/.vscode/launch.json index b9bf15d..8bab3bc 100644 --- a/labs/.vscode/launch.json +++ b/labs/.vscode/launch.json @@ -14,6 +14,17 @@ "cwd": "${workspaceFolder}", "stopAtEntry": false, "console": "integratedTerminal" + }, + { + "name": "Launch 035-assistants-dotnet", + "type": "coreclr", + "request": "launch", + "preLaunchTask": "dotnet: build", + "program": "${workspaceFolder}/035-assistants-dotnet/bin/Debug/net8.0/035-assistants-dotnet.dll", + "args": [], + "cwd": "${workspaceFolder}", + "stopAtEntry": false, + "console": "integratedTerminal" } ] diff --git a/labs/035-assistants-dotnet/035-assistants-dotnet.csproj b/labs/035-assistants-dotnet/035-assistants-dotnet.csproj new file mode 100644 index 0000000..11bd754 --- /dev/null +++ b/labs/035-assistants-dotnet/035-assistants-dotnet.csproj @@ -0,0 +1,18 @@ + + + + Exe + net8.0 + _035_assistants_dotnet + enable + enable + + + + + + + + + + diff --git a/labs/035-assistants-dotnet/Functions.cs b/labs/035-assistants-dotnet/Functions.cs new file mode 100644 index 0000000..21fcc55 --- /dev/null +++ b/labs/035-assistants-dotnet/Functions.cs @@ -0,0 +1,250 @@ +using System.Data; +using System.Text; +using Azure.AI.OpenAI.Assistants; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace AssistantsDotNet; + +public static class Functions +{ + public static readonly FunctionToolDefinition GetCustomersFunctionDefinition = new( + "getCustomers", + "Gets a filtered list of customers. At least one filter MUST be provided in the parameters. The result list is limited to 25 customers.", + JsonHelpers.FromObjectAsJson( + new + { + Type = "object", + Properties = new + { + CustomerID = new { Type = "integer", Description = "Optional filter for the customer ID." }, + FirstName = new { Type = "string", Description = "Optional filter for the first name (true if first name contains filter value)." }, + MiddleName = new { Type = "string", Description = "Optional filter for the middle name (true if middle name contains filter value)." }, + LastName = new { Type = "string", Description = "Optional filter for the last name (true if last name contains filter value)." }, + CompanyName = new { Type = "string", Description = "Optional filter for the company name (true if company name contains filter value)." } + }, + Required = Array.Empty() + }) + ); + + public class GetCustomersParameters + { + public int? CustomerID { get; set; } + public string? FirstName { get; set; } + public string? MiddleName { get; set; } + public string? LastName { get; set; } + public string? CompanyName { get; set; } + } + + public class Customer + { + public int CustomerID { get; set; } + public string? FirstName { get; set; } + public string? MiddleName { get; set; } + public string? LastName { get; set; } + public string? CompanyName { get; set; } + } + + public static async Task> GetCustomers(SqlConnection connection, GetCustomersParameters filter) + { + if (!filter.CustomerID.HasValue && string.IsNullOrEmpty(filter.FirstName) && string.IsNullOrEmpty(filter.MiddleName) && string.IsNullOrEmpty(filter.LastName) && string.IsNullOrEmpty(filter.CompanyName)) + { + throw new Exception("At least one filter must be provided."); + } + + var query = new StringBuilder("SELECT TOP 25 CustomerID, FirstName, MiddleName, LastName, CompanyName FROM SalesLT.Customer WHERE CustomerID >= 29485"); + var parameters = new DynamicParameters(); + + if (filter.CustomerID.HasValue) + { + query.Append(" AND CustomerID = @customerID"); + parameters.Add("customerID", filter.CustomerID.Value, DbType.Int32); + } + if (!string.IsNullOrEmpty(filter.FirstName)) + { + query.Append(" AND FirstName LIKE '%' + @firstName + '%'"); + parameters.Add("firstName", filter.FirstName, DbType.String); + } + if (!string.IsNullOrEmpty(filter.MiddleName)) + { + query.Append(" AND MiddleName LIKE '%' + @middleName + '%'"); + parameters.Add("middleName", filter.MiddleName, DbType.String); + } + if (!string.IsNullOrEmpty(filter.LastName)) + { + query.Append(" AND LastName LIKE '%' + @lastName + '%'"); + parameters.Add("lastName", filter.LastName, DbType.String); + } + if (!string.IsNullOrEmpty(filter.CompanyName)) + { + query.Append(" AND CompanyName LIKE '%' + @companyName + '%'"); + parameters.Add("companyName", filter.CompanyName, DbType.String); + } + + Console.WriteLine($"Executing query: {query}"); + + var result = await connection.QueryAsync(query.ToString(), parameters); + + return result; + } + + public static readonly FunctionToolDefinition GetProductsFunctionDefinition = new( + "getProducts", + "Gets a filtered list of products. At least one filter MUST be provided in the parameters. The result list is limited to 25 products.", + JsonHelpers.FromObjectAsJson( + new + { + Type = "object", + Properties = new + { + ProductID = new { Type = "integer", Description = "Optional filter for the product ID." }, + Name = new { Type = "string", Description = "Optional filter for the product name (true if product name contains filter value)." }, + ProductNumber = new { Type = "string", Description = "Optional filter for the product number." } + }, + Required = Array.Empty() + }) + ); + + public class GetProductsParameters + { + public int? ProductID { get; set; } + public string? Name { get; set; } + public string? ProductNumber { get; set; } + } + + public class Product + { + public int ProductID { get; set; } + public string? Name { get; set; } + public string? ProductNumber { get; set; } + public int ProductCategoryID { get; set; } + } + + public static async Task> GetProducts(SqlConnection connection, GetProductsParameters filter) + { + if (!filter.ProductID.HasValue && string.IsNullOrEmpty(filter.Name) && string.IsNullOrEmpty(filter.ProductNumber)) + { + throw new Exception("At least one filter must be provided."); + } + + var query = new StringBuilder("SELECT TOP 25 ProductID, Name, ProductNumber, ProductCategoryID FROM SalesLT.Product WHERE 1 = 1"); + var parameters = new DynamicParameters(); + + if (filter.ProductID.HasValue) + { + query.Append(" AND ProductID = @productID"); + parameters.Add("productID", filter.ProductID.Value, DbType.Int32); + } + if (!string.IsNullOrEmpty(filter.Name)) + { + query.Append(" AND Name LIKE '%' + @name + '%'"); + parameters.Add("name", filter.Name, DbType.String); + } + if (!string.IsNullOrEmpty(filter.ProductNumber)) + { + query.Append(" AND ProductNumber = @productNumber"); + parameters.Add("productNumber", filter.ProductNumber, DbType.String); + } + + Console.WriteLine($"Executing query: {query}"); + + var result = await connection.QueryAsync(query.ToString(), parameters); + + return result; + } + + public static readonly FunctionToolDefinition GetCustomerProductsRevenueFunctionDefinition = new( + "getCustomerProductsRevenue", + "Gets the revenue of the customer and products. The result is ordered by the revenue in descending order. The result list is limited to 25 records.", + JsonHelpers.FromObjectAsJson( + new + { + Type = "object", + Properties = new + { + CustomerID = new { Type = "integer", Description = "Optional filter for the customer ID." }, + ProductID = new { Type = "integer", Description = "Optional filter for the product ID." }, + Year = new { Type = "integer", Description = "Optional filter for the year." }, + Month = new { Type = "integer", Description = "Optional filter for the month." }, + GroupByCustomer = new { Type = "boolean", Description = "If true, revenue is grouped by customer ID." }, + GroupByProduct = new { Type = "boolean", Description = "If true, revenue is grouped by product ID." }, + GroupByYear = new { Type = "boolean", Description = "If true, revenue is grouped by year." }, + GroupByMonth = new { Type = "boolean", Description = "If true, revenue is grouped by month." } + }, + Required = Array.Empty() + }) + ); + + public class GetCustomerProductsRevenueParameters + { + public int? CustomerID { get; set; } + public int? ProductID { get; set; } + public int? Year { get; set; } + public int? Month { get; set; } + public bool? GroupByCustomer { get; set; } + public bool? GroupByProduct { get; set; } + public bool? GroupByYear { get; set; } + public bool? GroupByMonth { get; set; } + } + + public class CustomerProductsRevenue + { + public decimal Revenue { get; set; } + public int? CustomerID { get; set; } + public int? ProductID { get; set; } + public int? Year { get; set; } + public int? Month { get; set; } + } + + public static async Task> GetCustomerProductsRevenue(SqlConnection connection, GetCustomerProductsRevenueParameters filter) + { + var query = new StringBuilder("SELECT TOP 25 SUM(LineTotal) AS Revenue"); + var parameters = new DynamicParameters(); + + if (filter.GroupByCustomer.HasValue && filter.GroupByCustomer.Value) { query.Append(", CustomerID"); } + if (filter.GroupByProduct.HasValue && filter.GroupByProduct.Value) { query.Append(", ProductID"); } + if (filter.GroupByYear.HasValue && filter.GroupByYear.Value) { query.Append(", YEAR(OrderDate) AS Year"); } + if (filter.GroupByMonth.HasValue && filter.GroupByMonth.Value) { query.Append(", MONTH(OrderDate) AS Month"); } + + query.Append(" FROM SalesLT.SalesOrderDetail d INNER JOIN SalesLT.SalesOrderHeader h ON d.SalesOrderID = h.SalesOrderID WHERE 1 = 1"); + + if (filter.CustomerID.HasValue) + { + query.Append(" AND CustomerID = @customerID"); + parameters.Add("customerID", filter.CustomerID.Value, DbType.Int32); + } + if (filter.ProductID.HasValue) + { + query.Append(" AND ProductID = @productID"); + parameters.Add("productID", filter.ProductID.Value, DbType.Int32); + } + if (filter.Year.HasValue) + { + query.Append(" AND YEAR(OrderDate) = @year"); + parameters.Add("year", filter.Year.Value, DbType.Int32); + } + if (filter.Month.HasValue) + { + query.Append(" AND MONTH(OrderDate) = @month"); + parameters.Add("month", filter.Month.Value, DbType.Int32); + } + + if (filter.GroupByCustomer.HasValue || filter.GroupByProduct.HasValue || filter.GroupByYear.HasValue || filter.GroupByMonth.HasValue) + { + var groupColumns = new List(); + if (filter.GroupByCustomer.HasValue && filter.GroupByCustomer.Value) { groupColumns.Add("CustomerID"); } + if (filter.GroupByProduct.HasValue && filter.GroupByProduct.Value) { groupColumns.Add("ProductID"); } + if (filter.GroupByYear.HasValue && filter.GroupByYear.Value) { groupColumns.Add("YEAR(OrderDate)"); } + if (filter.GroupByMonth.HasValue && filter.GroupByMonth.Value) { groupColumns.Add("MONTH(OrderDate)"); } + query.Append($" GROUP BY {string.Join(", ", groupColumns)}"); + } + + query.Append(" ORDER BY SUM(LineTotal) DESC"); + + Console.WriteLine($"Executing query: {query.ToString()}"); + + var result = await connection.QueryAsync(query.ToString(), parameters); + + return result.ToList(); + } +} diff --git a/labs/035-assistants-dotnet/JsonHelpers.cs b/labs/035-assistants-dotnet/JsonHelpers.cs new file mode 100644 index 0000000..294642c --- /dev/null +++ b/labs/035-assistants-dotnet/JsonHelpers.cs @@ -0,0 +1,14 @@ +using System.Text.Json; + +namespace AssistantsDotNet; + +static class JsonHelpers +{ + private static readonly JsonSerializerOptions JsonSerializerOptions = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + + public static string Serialize(T value) => JsonSerializer.Serialize(value, JsonSerializerOptions); + + public static T Deserialize(string json) => JsonSerializer.Deserialize(json, JsonSerializerOptions)!; + + public static BinaryData FromObjectAsJson(object value) => BinaryData.FromObjectAsJson(value, JsonSerializerOptions); +} diff --git a/labs/035-assistants-dotnet/OpenAIExtensions.cs b/labs/035-assistants-dotnet/OpenAIExtensions.cs new file mode 100644 index 0000000..d39114e --- /dev/null +++ b/labs/035-assistants-dotnet/OpenAIExtensions.cs @@ -0,0 +1,128 @@ +using Azure.AI.OpenAI.Assistants; + +namespace AssistantsDotNet; + +static class OpenAIExtensions +{ + public static async Task FindAssistantByName(this AssistantsClient client, string name) + { + PageableList assistants; + string? after = null; + do + { + assistants = await client.GetAssistantsAsync(after: after); + foreach (var assistant in assistants) + { + if (assistant.Name == name) { return assistant; } + } + + after = assistants.LastId; + } + while (assistants.HasMore); + + return null; + } + + public static async Task CreateOrUpdate(this AssistantsClient client, AssistantCreationOptions assistant) + { + var existing = await client.FindAssistantByName(assistant.Name); + if (existing != null) + { + var updateOptions = new UpdateAssistantOptions() + { + Model = assistant.Model, + Name = assistant.Name, + Description = assistant.Description, + Instructions = assistant.Instructions, + Metadata = assistant.Metadata + }; + foreach (var tool in assistant.Tools) { updateOptions.Tools.Add(tool); } + foreach (var fileId in assistant.FileIds) { updateOptions.FileIds.Add(fileId); } + + return await client.UpdateAssistantAsync(existing.Id, updateOptions); + } + + return await client.CreateAssistantAsync(assistant); + } + + public static async Task AddMessageAndRunToCompletion(this AssistantsClient client, string threadId, string assistantId, + string message, Func>? functionCallback = null) + { + await client.CreateMessageAsync(threadId, MessageRole.User, message); + var run = await client.CreateRunAsync(threadId, new CreateRunOptions(assistantId)); + Console.WriteLine($"Run created { run.Value.Id }"); + + while (run.Value.Status == RunStatus.Queued || run.Value.Status == RunStatus.InProgress || run.Value.Status == RunStatus.Cancelling || run.Value.Status == RunStatus.RequiresAction) + { + Console.WriteLine($"Run status { run.Value.Status }"); + + var steps = await client.GetRunStepsAsync(run, 1, ListSortOrder.Descending); + + // If last step is a code interpreter call, log it (including generated Python code) + if (steps.Value.Any() && steps.Value.First().StepDetails is RunStepToolCallDetails toolCallDetails) + { + foreach(var call in toolCallDetails.ToolCalls) + { + if (call is RunStepCodeInterpreterToolCall codeInterpreterToolCall && !string.IsNullOrEmpty(codeInterpreterToolCall.Input)) + { + Console.WriteLine($"Code Interpreter Tool Call: {codeInterpreterToolCall.Input}"); + } + } + } + + // Check if the run requires us to execute a function + if (run.Value.Status == RunStatus.RequiresAction && functionCallback != null) + { + var toolOutput = new List(); + if (steps.Value.First().StepDetails is RunStepToolCallDetails stepDetails) + { + foreach(var call in stepDetails.ToolCalls.OfType()) + { + Console.WriteLine($"Calling function { call.Id } { call.Name } { call.Arguments }"); + + string functionResponse; + try + { + var result = await functionCallback(call); + functionResponse = JsonHelpers.Serialize(result); + } + catch (Exception ex) + { + Console.WriteLine($"Function call failed, returning error message to ChatGPT { call.Name } { ex.Message }"); + functionResponse = ex.Message; + } + + toolOutput.Add(new() + { + ToolCallId = call.Id, + Output = functionResponse + }); + } + } + + if (toolOutput.Count != 0) + { + run = await client.SubmitToolOutputsToRunAsync(threadId, run.Value.Id, toolOutput); + } + } + + + await Task.Delay(1000); + run = await client.GetRunAsync(threadId, run.Value.Id); + } + + Console.WriteLine($"Final run status { run.Value.Status }"); + return run; + } + + public static async Task GetLatestMessage(this AssistantsClient client, string threadId) + { + var messages = await client.GetMessagesAsync(threadId, 1, ListSortOrder.Descending); + if (messages.Value.FirstOrDefault()?.ContentItems[0] is MessageTextContent tc) + { + return tc.Text; + } + + return null; + } +} diff --git a/labs/035-assistants-dotnet/Program.cs b/labs/035-assistants-dotnet/Program.cs new file mode 100644 index 0000000..3f11ca5 --- /dev/null +++ b/labs/035-assistants-dotnet/Program.cs @@ -0,0 +1,87 @@ +using AssistantsDotNet; +using Azure.AI.OpenAI.Assistants; +using dotenv.net; +using Microsoft.Data.SqlClient; + +// Get environment variables from .env file. We have to go up 7 levels to get to the root of the +// git repository (because of bin/Debug/net8.0 folder). +var env = DotEnv.Read(options: new DotEnvOptions(probeForEnv: true, probeLevelsToSearch: 7)); + +// Open connection to Adventure Works +using var sqlConnection = new SqlConnection(env["ADVENTURE_WORKS"]); +await sqlConnection.OpenAsync(); + +// In this sample, we use key-based authentication. This is only done because this sample +// will be done by a larger group in a hackathon event. In real world, AVOID key-based +// authentication. ALWAYS prefer Microsoft Entra-based authentication (Managed Identity)! +var client = new AssistantsClient(env["OPENAI_KEY"]); + +var assistant = await client.CreateOrUpdate(new(env["OPENAI_MODEL"]) +{ + Name = "Revenue Analyzer", + Description = "Retrieves customer and product revenue and analyzes it using code interpreter", + Instructions = """ + You are an assistant supporting business users who need to analyze the revene of + customers and products. Use the provided function tools to access the order database + and answer the user's questions. + + Only answer questions related to customer and product revenue. If the user asks + questions not related to this topic, tell her or him that you cannot + answer such questions. + + If the user asks a question that cannot be answered with the provided function tools, + tell her or him that you cannot answer the question because of a lack of access + to the required data. + """, + Tools = { + new CodeInterpreterToolDefinition(), + Functions.GetCustomersFunctionDefinition, + Functions.GetProductsFunctionDefinition, + Functions.GetCustomerProductsRevenueFunctionDefinition, + } +}); + +var thread = await client.CreateThreadAsync(); +while (true) +{ + string[] options = + [ + "I will visit Orlando Gee tomorrow. Give me a revenue breakdown of his revenue per product (absolute revenue and percentages). Also show me his total revenue.", + "Now show me a table with his revenue per year and month.", + "The table is missing some months. Probably because they did not buy anything in those months. Complete the table by adding 0 revenue for all missing months.", + "Show me the data in a table. Include not just percentage values, but also absolute revenue" + ]; + Console.WriteLine("\n"); + for (int i = 0; i < options.Length; i++) + { + Console.WriteLine($"{i + 1}: {options[i]}"); + } + Console.Write("You (just press enter to exit the conversation): "); + var userMessage = Console.ReadLine(); + if (string.IsNullOrEmpty(userMessage)) { break; } + if (int.TryParse(userMessage, out int selection) && selection >= 1 && selection <= options.Length) + { + userMessage = options[selection - 1]; + } + + var run = await client.AddMessageAndRunToCompletion(thread.Value.Id, assistant.Id, userMessage, async functionCall => + { + switch (functionCall.Name) + { + case "getCustomers": + return await Functions.GetCustomers(sqlConnection, JsonHelpers.Deserialize(functionCall.Arguments)!); + case "getProducts": + return await Functions.GetProducts(sqlConnection, JsonHelpers.Deserialize(functionCall.Arguments)!); + case "getCustomerProductsRevenue": + return await Functions.GetCustomerProductsRevenue(sqlConnection, JsonHelpers.Deserialize(functionCall.Arguments)!); + default: + throw new Exception($"Function {functionCall.Name} is not supported"); + } + }); + + if (run.Status == "completed") + { + var lastMessage = await client.GetLatestMessage(thread.Value.Id); + Console.WriteLine($"\n🤖: {lastMessage}"); + } +} diff --git a/labs/Directory.Packages.props b/labs/Directory.Packages.props index f9a30bb..08ee854 100644 --- a/labs/Directory.Packages.props +++ b/labs/Directory.Packages.props @@ -1,10 +1,12 @@ - - true - - - - - - + + true + + + + + + + + \ No newline at end of file diff --git a/labs/HandsOnLabs.sln b/labs/HandsOnLabs.sln index f6ca7b8..db9fb34 100644 --- a/labs/HandsOnLabs.sln +++ b/labs/HandsOnLabs.sln @@ -11,6 +11,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "020-functions", "020-functi EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FunctionCallingDotNet", "020-functions-dotnet\FunctionCallingDotNet\FunctionCallingDotNet.csproj", "{A71026F5-ACF4-49FF-94C4-73006234E737}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "035-assistants-dotnet", "035-assistants-dotnet\035-assistants-dotnet.csproj", "{4410FBFF-81EE-44E5-82CC-82E619380D87}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -28,6 +30,10 @@ Global {A71026F5-ACF4-49FF-94C4-73006234E737}.Debug|Any CPU.Build.0 = Debug|Any CPU {A71026F5-ACF4-49FF-94C4-73006234E737}.Release|Any CPU.ActiveCfg = Release|Any CPU {A71026F5-ACF4-49FF-94C4-73006234E737}.Release|Any CPU.Build.0 = Release|Any CPU + {4410FBFF-81EE-44E5-82CC-82E619380D87}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4410FBFF-81EE-44E5-82CC-82E619380D87}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4410FBFF-81EE-44E5-82CC-82E619380D87}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4410FBFF-81EE-44E5-82CC-82E619380D87}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(NestedProjects) = preSolution {99F0BCCA-ED5E-4654-937F-C79E7E2130C4} = {663030A0-0A93-4565-BEE2-2F356E49FEFA} diff --git a/readme.md b/readme.md index f1d9a67..1e24ced 100644 --- a/readme.md +++ b/readme.md @@ -6,7 +6,9 @@ This repository contains demos and samples for the [_Microsoft Build: AI Day (Au * OpenAI Chat Completions Basics ([C# and .NET](./labs/010-basics-dotnet/) or [Python](./labs/015-basics-python/)) * [Function Calling with Chat Completions](./labs/020-functions-dotnet/) (C# and .NET) -* [Using Tools with the new _Assistant_ API](./labs/030-assistants-nodejs/) (TypeScript and Node.js) +* Using Tools with the new _Assistant_ API + * [TypeScript and Node.js](./labs/030-assistants-nodejs/) + * [C# and .NET](./labs/035-assistants-dotnet/) * [Embeddings and the RAG model](./labs/040-embeddings-rag-nodejs/) (TypeScript and Node.js) Attendees can decide the complexity level on their own: