From e817b7dc110ddc0749a9968e2a90a049e3d44137 Mon Sep 17 00:00:00 2001
From: Sam
Date: Tue, 12 Nov 2024 08:14:30 +1100
Subject: [PATCH] FEATURE: improve tool support (#904)

This re-implements tool support in DiscourseAi::Completions::Llm#generate.

Previously, tool support was always returned via XML and it was the
responsibility of the caller to parse that XML. The new implementation has
the endpoints return ToolCall objects.

Additionally, this simplifies the Llm endpoint interface and gives it more
clarity. Llms must implement #decode and #decode_chunk (for streaming). It
is the implementer's responsibility to figure out how to decode chunks; the
base class no longer implements this. To make this easy we ship a flexible
JSON stream decoder which is easy to wire up.

Also (new): better debugging for PMs - we now have next / previous buttons
to step through all the Llm messages associated with a PM.

Token accounting is fixed for vLLM (we were not correctly counting tokens).
---
 .../discourse_ai/ai_bot/bot_controller.rb | 8 +
 app/models/ai_api_audit_log.rb | 8 +
 .../ai_api_audit_log_serializer.rb | 4 +-
 .../components/modal/debug-ai-modal.gjs | 51 ++-
 config/locales/client.en.yml | 2 +
 config/routes.rb | 1 +
 lib/ai_bot/bot.rb | 19 +-
 lib/ai_bot/personas/persona.rb | 21 +-
 .../anthropic_message_processor.rb | 118 ++++---
 lib/completions/dialects/ollama.rb | 19 +-
 lib/completions/endpoints/anthropic.rb | 33 +-
 lib/completions/endpoints/aws_bedrock.rb | 48 ++-
 lib/completions/endpoints/base.rb | 312 +++++++-----------
 lib/completions/endpoints/canned_response.rb | 22 +-
 lib/completions/endpoints/cohere.rb | 89 +++--
 lib/completions/endpoints/fake.rb | 52 +--
 lib/completions/endpoints/gemini.rb | 108 ++++--
 lib/completions/endpoints/hugging_face.rb | 34 +-
 lib/completions/endpoints/ollama.rb | 90 ++---
 lib/completions/endpoints/open_ai.rb | 96 +-----
 lib/completions/endpoints/samba_nova.rb | 36 +-
 lib/completions/endpoints/vllm.rb | 51 ++-
 lib/completions/function_call_normalizer.rb | 113 -------
 lib/completions/json_stream_decoder.rb | 48 +++
 lib/completions/open_ai_message_processor.rb | 103 ++++++
 lib/completions/tool_call.rb | 29 ++
 lib/completions/xml_tool_processor.rb | 124 +++++++
 .../completions/endpoints/anthropic_spec.rb | 52 +--
 .../completions/endpoints/aws_bedrock_spec.rb | 49 ++-
 spec/lib/completions/endpoints/cohere_spec.rb | 25 +-
 .../endpoints/endpoint_compliance.rb | 29 +-
 spec/lib/completions/endpoints/gemini_spec.rb | 89 ++++-
 spec/lib/completions/endpoints/ollama_spec.rb | 2 +-
 .../lib/completions/endpoints/open_ai_spec.rb | 119 +++---
 .../completions/endpoints/samba_nova_spec.rb | 9 +-
 spec/lib/completions/endpoints/vllm_spec.rb | 128 ++++++-
 .../function_call_normalizer_spec.rb | 182 ----------
 .../completions/json_stream_decoder_spec.rb | 47 +++
 .../completions/xml_tool_processor_spec.rb | 188 +++++++++++
 .../modules/ai_bot/personas/persona_spec.rb | 193 ++++++-----
 spec/lib/modules/ai_bot/playground_spec.rb | 157 ++++-----
 .../admin/ai_personas_controller_spec.rb | 14 +-
 spec/requests/ai_bot/bot_controller_spec.rb | 56 +++-
 43 files changed, 1685 insertions(+), 1293 deletions(-)
 delete mode 100644 lib/completions/function_call_normalizer.rb
 create mode 100644 lib/completions/json_stream_decoder.rb
 create mode 100644 lib/completions/open_ai_message_processor.rb
 create mode 100644 lib/completions/tool_call.rb
 create mode 100644 lib/completions/xml_tool_processor.rb
 delete mode 100644 spec/lib/completions/function_call_normalizer_spec.rb
 create mode 100644 spec/lib/completions/json_stream_decoder_spec.rb
 create mode 
100644 spec/lib/completions/xml_tool_processor_spec.rb diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb index e5d5bcf00..5ea13795e 100644 --- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb +++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb @@ -6,6 +6,14 @@ class BotController < ::ApplicationController requires_plugin ::DiscourseAi::PLUGIN_NAME requires_login + def show_debug_info_by_id + log = AiApiAuditLog.find(params[:id]) + raise Discourse::NotFound if !log.topic + + guardian.ensure_can_debug_ai_bot_conversation!(log.topic) + render json: AiApiAuditLogSerializer.new(log, root: false), status: 200 + end + def show_debug_info post = Post.find(params[:post_id]) guardian.ensure_can_debug_ai_bot_conversation!(post) diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb index 2fa9f5c37..2fa0a2140 100644 --- a/app/models/ai_api_audit_log.rb +++ b/app/models/ai_api_audit_log.rb @@ -14,6 +14,14 @@ module Provider Ollama = 7 SambaNova = 8 end + + def next_log_id + self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first + end + + def prev_log_id + self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first + end end # == Schema Information diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb index 0c438a7b4..eeb3843ac 100644 --- a/app/serializers/ai_api_audit_log_serializer.rb +++ b/app/serializers/ai_api_audit_log_serializer.rb @@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer :post_id, :feature_name, :language_model, - :created_at + :created_at, + :prev_log_id, + :next_log_id end diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs index 5d0cdf692..c21e8df37 100644 --- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs +++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs @@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template"; import DButton from "discourse/components/d-button"; import DModal from "discourse/components/d-modal"; import { ajax } from "discourse/lib/ajax"; +import { popupAjaxError } from "discourse/lib/ajax-error"; import { clipboardCopy, escapeExpression } from "discourse/lib/utilities"; import i18n from "discourse-common/helpers/i18n"; import discourseLater from "discourse-common/lib/later"; @@ -63,6 +64,28 @@ export default class DebugAiModal extends Component { this.copy(this.info.raw_response_payload); } + async loadLog(logId) { + try { + await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then( + (result) => { + this.info = result; + } + ); + } catch (e) { + popupAjaxError(e); + } + } + + @action + prevLog() { + this.loadLog(this.info.prev_log_id); + } + + @action + nextLog() { + this.loadLog(this.info.next_log_id); + } + copy(text) { clipboardCopy(text); this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared"); @@ -73,11 +96,13 @@ export default class DebugAiModal extends Component { } loadApiRequestInfo() { - ajax( - `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json` - ).then((result) => { - this.info = result; - }); + ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`) + .then((result) => { + this.info = result; + }) + .catch((e) => { + popupAjaxError(e); + }); } get requestActive() { @@ -147,6 +172,22 @@ export default 
class DebugAiModal extends Component { @action={{this.copyResponse}} @label="discourse_ai.ai_bot.debug_ai_modal.copy_response" /> + {{#if this.info.prev_log_id}} + + {{/if}} + {{#if this.info.next_log_id}} + + {{/if}} {{this.justCopiedText}} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 2d5179461..82898c91d 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -415,6 +415,8 @@ en: response_tokens: "Response tokens:" request: "Request" response: "Response" + next_log: "Next" + previous_log: "Previous" share_full_topic_modal: title: "Share Conversation Publicly" diff --git a/config/routes.rb b/config/routes.rb index a5c009ff6..322e67ce1 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -22,6 +22,7 @@ scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do get "bot-username" => "bot#show_bot_username" get "post/:post_id/show-debug-info" => "bot#show_debug_info" + get "show-debug-info/:id" => "bot#show_debug_info_by_id" post "post/:post_id/stop-streaming" => "bot#stop_streaming_response" end diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 834ae059f..b965b1f60 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -100,6 +100,7 @@ def reply(context, &update_blk) llm_kwargs[:top_p] = persona.top_p if persona.top_p needs_newlines = false + tools_ran = 0 while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false @@ -107,9 +108,10 @@ def reply(context, &update_blk) result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| - tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context) + tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context) + tool = nil if tools_ran >= MAX_TOOLS - if (tools.present?) + if tool.present? tool_found = true # a bit hacky, but extra newlines do no harm if needs_newlines @@ -117,13 +119,16 @@ def reply(context, &update_blk) needs_newlines = false end - tools[0..MAX_TOOLS].each do |tool| - process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) - ongoing_chain &&= tool.chain_next_response? - end + process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) + tools_ran += 1 + ongoing_chain &&= tool.chain_next_response? 
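A note on the contract this loop now relies on: partials yielded by Llm#generate are either String fragments or DiscourseAi::Completions::ToolCall objects, so callers branch on type instead of parsing XML. A minimal caller-side sketch (the helper names are hypothetical, not part of this patch):

    llm.generate(prompt, user: user, feature_name: "bot") do |partial, cancel|
      if partial.is_a?(DiscourseAi::Completions::ToolCall)
        # a ToolCall arrives fully parsed: partial.id, partial.name, partial.parameters
        run_tool(partial.name, partial.parameters) # hypothetical helper
      else
        stream_to_client(partial) # hypothetical helper; partial is a String fragment
      end
    end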
else needs_newlines = true - update_blk.call(partial, cancel) + if partial.is_a?(DiscourseAi::Completions::ToolCall) + Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}") + else + update_blk.call(partial, cancel) + end end end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 73224808a..63255a172 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -199,23 +199,16 @@ def craft_prompt(context, llm: nil) prompt end - def find_tools(partial, bot_user:, llm:, context:) - return [] if !partial.include?("") - - parsed_function = Nokogiri::HTML5.fragment(partial) - parsed_function - .css("invoke") - .map do |fragment| - tool_instance(fragment, bot_user: bot_user, llm: llm, context: context) - end - .compact + def find_tool(partial, bot_user:, llm:, context:) + return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall) + tool_instance(partial, bot_user: bot_user, llm: llm, context: context) end protected - def tool_instance(parsed_function, bot_user:, llm:, context:) - function_id = parsed_function.at("tool_id")&.text - function_name = parsed_function.at("tool_name")&.text + def tool_instance(tool_call, bot_user:, llm:, context:) + function_id = tool_call.id + function_name = tool_call.name return nil if function_name.nil? tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name } @@ -224,7 +217,7 @@ def tool_instance(parsed_function, bot_user:, llm:, context:) arguments = {} tool_klass.signature[:parameters].to_a.each do |param| name = param[:name] - value = parsed_function.at(name)&.text + value = tool_call.parameters[name.to_sym] if param[:type] == "array" && value value = diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb index 1d1516fa4..5d5602ef5 100644 --- a/lib/completions/anthropic_message_processor.rb +++ b/lib/completions/anthropic_message_processor.rb @@ -13,6 +13,11 @@ def initialize(name, id) def append(json) @raw_json << json end + + def to_tool_call + parameters = JSON.parse(raw_json, symbolize_names: true) + DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters) + end end attr_reader :tool_calls, :input_tokens, :output_tokens @@ -20,80 +25,69 @@ def append(json) def initialize(streaming_mode:) @streaming_mode = streaming_mode @tool_calls = [] + @current_tool_call = nil end - def to_xml_tool_calls(function_buffer) - return function_buffer if @tool_calls.blank? - - function_buffer = Nokogiri::HTML5.fragment(<<~TEXT) - - - TEXT - - @tool_calls.each do |tool_call| - node = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse( - DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n", - ), - ) - - params = JSON.parse(tool_call.raw_json, symbolize_names: true) - xml = - params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}" }.join("\n") + def to_tool_calls + @tool_calls.map { |tool_call| tool_call.to_tool_call } + end - node.at("tool_name").content = tool_call.name - node.at("tool_id").content = tool_call.id - node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present? 
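The ToolCall built above is a small value object added by this patch (lib/completions/tool_call.rb). A quick sketch of its behaviour:

    call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "search",
        parameters: { "query" => "sam" },
      )
    call.parameters # => { query: "sam" } - keys are always symbolized
    # equality is by value (id, name and parameters)
    call == DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "search", parameters: { query: "sam" })
    # => true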
+ def process_streamed_message(parsed) + result = nil + if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use" + tool_name = parsed.dig(:content_block, :name) + tool_id = parsed.dig(:content_block, :id) + result = @current_tool_call.to_tool_call if @current_tool_call + @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name + elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" + if @current_tool_call + tool_delta = parsed.dig(:delta, :partial_json).to_s + @current_tool_call.append(tool_delta) + else + result = parsed.dig(:delta, :text).to_s + end + elsif parsed[:type] == "content_block_stop" + if @current_tool_call + result = @current_tool_call.to_tool_call + @current_tool_call = nil + end + elsif parsed[:type] == "message_start" + @input_tokens = parsed.dig(:message, :usage, :input_tokens) + elsif parsed[:type] == "message_delta" + @output_tokens = + parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens) + elsif parsed[:type] == "message_stop" + # bedrock has this ... + if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym) + @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens + @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens + end end - - function_buffer + result end def process_message(payload) result = "" - parsed = JSON.parse(payload, symbolize_names: true) + parsed = payload + parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String) - if @streaming_mode - if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use" - tool_name = parsed.dig(:content_block, :name) - tool_id = parsed.dig(:content_block, :id) - @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name - elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" - if @tool_calls.present? - result = parsed.dig(:delta, :partial_json).to_s - @tool_calls.last.append(result) - else - result = parsed.dig(:delta, :text).to_s - end - elsif parsed[:type] == "message_start" - @input_tokens = parsed.dig(:message, :usage, :input_tokens) - elsif parsed[:type] == "message_delta" - @output_tokens = - parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens) - elsif parsed[:type] == "message_stop" - # bedrock has this ... 
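To make the streaming path concrete, here is roughly how an abbreviated Anthropic event sequence maps to yielded partials (real payloads carry more fields):

    processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

    processor.process_streamed_message({ type: "content_block_delta", delta: { text: "Hello" } })
    # => "Hello"

    processor.process_streamed_message(
      { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
    ) # => nil, a tool call starts buffering

    processor.process_streamed_message(
      { type: "content_block_delta", delta: { partial_json: "{\"query\":\"sam\"}" } },
    ) # => nil, JSON arguments accumulate

    processor.process_streamed_message({ type: "content_block_stop" })
    # => ToolCall(id: "toolu_1", name: "search", parameters: { query: "sam" })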
- if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym) - @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens - @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens - end - end - else - content = parsed.dig(:content) - if content.is_a?(Array) - tool_call = content.find { |c| c[:type] == "tool_use" } - if tool_call - @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id]) - @tool_calls.last.append(tool_call[:input].to_json) - else - result = parsed.dig(:content, 0, :text).to_s + content = parsed.dig(:content) + if content.is_a?(Array) + result = + content.map do |data| + if data[:type] == "tool_use" + call = AnthropicToolCall.new(data[:name], data[:id]) + call.append(data[:input].to_json) + call.to_tool_call + else + data[:text] + end end - end - - @input_tokens = parsed.dig(:usage, :input_tokens) - @output_tokens = parsed.dig(:usage, :output_tokens) end + @input_tokens = parsed.dig(:usage, :input_tokens) + @output_tokens = parsed.dig(:usage, :output_tokens) + result end end diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index 541d0e733..3a32e5927 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -63,8 +63,23 @@ def enable_native_tool? def user_msg(msg) user_message = { role: "user", content: msg[:content] } - # TODO: Add support for user messages with empbeded user ids - # TODO: Add support for user messages with attachments + encoded_uploads = prompt.encoded_uploads(msg) + if encoded_uploads.present? + images = + encoded_uploads + .map do |upload| + if upload[:mime_type].start_with?("image/") + upload[:base64] + else + nil + end + end + .compact + + user_message[:images] = images if images.present? + end + + # TODO: Add support for user messages with embedded user ids user_message end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 44762b886..6576ef3bc 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -63,6 +63,10 @@ def model_uri URI(llm_model.url) end + def xml_tools_enabled? + !@native_tool_support + end + def prepare_payload(prompt, model_params, dialect) @native_tool_support = dialect.native_tool_support? @@ -90,35 +94,34 @@ def prepare_request(payload) Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def processor - @processor ||= - DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) + def decode_chunk(partial_data) + @decoder ||= JsonStreamDecoder.new + (@decoder << partial_data) + .map { |parsed_json| processor.process_streamed_message(parsed_json) } + .compact end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - processor.to_xml_tool_calls(function_buffer) if !partial + def decode(response_data) + processor.process_message(response_data) end - def extract_completion_from(response_raw) - processor.process_message(response_raw) + def processor + @processor ||= + DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) end def has_tool?(_response_data) processor.tool_calls.present? end + def tool_calls + processor.to_tool_calls + end + def final_log_update(log) log.request_tokens = processor.input_tokens if processor.input_tokens log.response_tokens = processor.output_tokens if processor.output_tokens end - - def native_tool_support? 
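The JsonStreamDecoder wired up in decode_chunk above (added by this patch) owns all the buffering, so endpoints no longer care where the network splits a chunk. A small sketch of the behaviour:

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new
    decoder << "data: {\"a\":" # => [] - the incomplete line is buffered
    decoder << "1}\n"          # => [{ a: 1 }] - parsed with symbolized keys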
- @native_tool_support - end - - def partials_from(decoded_chunk) - decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact - end end end end diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index f3146c2d7..c17a051fe 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -117,7 +117,24 @@ def prepare_request(payload) end end - def decode(chunk) + def decode_chunk(partial_data) + bedrock_decode(partial_data) + .map do |decoded_partial_data| + @raw_response ||= +"" + @raw_response << decoded_partial_data + @raw_response << "\n" + + parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true) + processor.process_streamed_message(parsed_json) + end + .compact + end + + def decode(response_data) + processor.process_message(response_data) + end + + def bedrock_decode(chunk) @decoder ||= Aws::EventStream::Decoder.new decoded, _done = @decoder.decode_chunk(chunk) @@ -147,12 +164,13 @@ def decode(chunk) Aws::EventStream::Errors::MessageChecksumError, Aws::EventStream::Errors::PreludeChecksumError => e Rails.logger.error("#{self.class.name}: #{e.message}") - nil + [] end def final_log_update(log) log.request_tokens = processor.input_tokens if processor.input_tokens log.response_tokens = processor.output_tokens if processor.output_tokens + log.raw_response_payload = @raw_response end def processor @@ -160,30 +178,8 @@ def processor DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - processor.to_xml_tool_calls(function_buffer) if !partial - end - - def extract_completion_from(response_raw) - processor.process_message(response_raw) - end - - def has_tool?(_response_data) - processor.tool_calls.present? - end - - def partials_from(decoded_chunks) - decoded_chunks - end - - def native_tool_support? - @native_tool_support - end - - def chunk_to_string(chunk) - joined = +chunk.join("\n") - joined << "\n" if joined.length > 0 - joined + def xml_tools_enabled? + !@native_tool_support end end end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index a0405b42c..c78fcdd98 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -40,10 +40,6 @@ def initialize(llm_model) @llm_model = llm_model end - def native_tool_support? - false - end - def use_ssl? if model_uri&.scheme.present? model_uri.scheme == "https" @@ -64,22 +60,10 @@ def perform_completion!( feature_context: nil, &blk ) - allow_tools = dialect.prompt.has_tools? model_params = normalize_model_params(model_params) orig_blk = blk @streaming_mode = block_given? - to_strip = xml_tags_to_strip(dialect) - @xml_stripper = - DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? - - if @streaming_mode && @xml_stripper - blk = - lambda do |partial, cancel| - partial = @xml_stripper << partial - orig_blk.call(partial, cancel) if partial - end - end prompt = dialect.translate @@ -108,177 +92,91 @@ def perform_completion!( raise CompletionFailed, response.body end + xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? && + dialect.prompt.has_tools? + + to_strip = xml_tags_to_strip(dialect) + xml_stripper = + DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? 
+ + if @streaming_mode && xml_stripper + blk = + lambda do |partial, cancel| + partial = xml_stripper << partial if partial.is_a?(String) + orig_blk.call(partial, cancel) if partial + end + end + log = - AiApiAuditLog.new( + start_log( provider_id: provider_id, - user_id: user&.id, - raw_request_payload: request_body, - request_tokens: prompt_size(prompt), - topic_id: dialect.prompt.topic_id, - post_id: dialect.prompt.post_id, + request_body: request_body, + dialect: dialect, + prompt: prompt, + user: user, feature_name: feature_name, - language_model: llm_model.name, - feature_context: feature_context.present? ? feature_context.as_json : nil, + feature_context: feature_context, ) if !@streaming_mode - response_raw = response.read_body - response_data = extract_completion_from(response_raw) - partials_raw = response_data.to_s - - if native_tool_support? - if allow_tools && has_tool?(response_data) - function_buffer = build_buffer # Nokogiri document - function_buffer = - add_to_function_buffer(function_buffer, payload: response_data) - FunctionCallNormalizer.normalize_function_ids!(function_buffer) - - response_data = +function_buffer.at("function_calls").to_s - response_data << "\n" - end - else - if allow_tools - response_data, function_calls = FunctionCallNormalizer.normalize(response_data) - response_data = function_calls if function_calls.present? - end - end - - return response_data + return( + non_streaming_response( + response: response, + xml_tool_processor: xml_tool_processor, + xml_stripper: xml_stripper, + partials_raw: partials_raw, + response_raw: response_raw, + ) + ) end - has_tool = false - begin cancelled = false cancel = -> { cancelled = true } - - wrapped_blk = ->(partial, inner_cancel) do - response_data << partial - blk.call(partial, inner_cancel) + if cancelled + http.finish + break end - normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel) - - leftover = "" - function_buffer = build_buffer # Nokogiri document - prev_processed_partials = 0 - response.read_body do |chunk| - if cancelled - http.finish - break - end - - decoded_chunk = decode(chunk) - if decoded_chunk.nil? - raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion" - end - response_raw << chunk_to_string(decoded_chunk) - - if decoded_chunk.is_a?(String) - redo_chunk = leftover + decoded_chunk - else - # custom implementation for endpoint - # no implicit leftover support - redo_chunk = decoded_chunk - end - - raw_partials = partials_from(redo_chunk) - - raw_partials = - raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0 - - if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?) - leftover = redo_chunk - next - end - - json_error = false - - raw_partials.each do |raw_partial| - json_error = false - prev_processed_partials += 1 - - next if cancelled - next if raw_partial.blank? - - begin - partial = extract_completion_from(raw_partial) - next if partial.nil? - # empty vs blank... we still accept " " - next if response_data.empty? && partial.empty? - partials_raw << partial.to_s - - if native_tool_support? - # Stop streaming the response as soon as you find a tool. - # We'll buffer and yield it later. 
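In contrast to the buffering machinery being removed here, the contract an endpoint now implements against Base is small. A minimal sketch (the class name is illustrative, modelled on the simpler endpoints in this patch):

    class MyEndpoint < DiscourseAi::Completions::Endpoints::Base
      # no native tool support: Base wires up the XmlToolProcessor when tools are present
      def xml_tools_enabled?
        true
      end

      # full (non-streaming) response body => array of partials
      def decode(response_raw)
        json = JSON.parse(response_raw, symbolize_names: true)
        [json.dig(:choices, 0, :message, :content)]
      end

      # raw network chunk => array of partials; the decoder handles buffering
      def decode_chunk(chunk)
        @decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
        (@decoder << chunk).map { |json| json.dig(:choices, 0, :delta, :content) }.compact
      end
    end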
- has_tool = true if allow_tools && has_tool?(partials_raw) - - if has_tool - function_buffer = - add_to_function_buffer(function_buffer, partial: partial) - else - response_data << partial - blk.call(partial, cancel) if partial - end - else - if allow_tools - normalizer << partial - else - response_data << partial - blk.call(partial, cancel) if partial - end + response_raw << chunk + decode_chunk(chunk).each do |partial| + partials_raw << partial.to_s + response_data << partial if partial.is_a?(String) + partials = [partial] + if xml_tool_processor && partial.is_a?(String) + partials = (xml_tool_processor << partial) + if xml_tool_processor.should_cancel? + cancel.call + break end - rescue JSON::ParserError - leftover = redo_chunk - json_error = true end + partials.each { |inner_partial| blk.call(inner_partial, cancel) } end - - if json_error - prev_processed_partials -= 1 - else - leftover = "" - end - - prev_processed_partials = 0 if leftover.blank? end rescue IOError, StandardError raise if !cancelled end - - has_tool ||= has_tool?(partials_raw) - # Once we have the full response, try to return the tool as a XML doc. - if has_tool && native_tool_support? - function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw) - - if function_buffer.at("tool_name").text.present? - FunctionCallNormalizer.normalize_function_ids!(function_buffer) - - invocation = +function_buffer.at("function_calls").to_s - invocation << "\n" - - response_data << invocation - blk.call(invocation, cancel) + if xml_stripper + stripped = xml_stripper.finish + if stripped.present? + response_data << stripped + result = [] + result = (xml_tool_processor << stripped) if xml_tool_processor + result.each { |partial| blk.call(partial, cancel) } end end - - if !native_tool_support? && function_calls = normalizer.function_calls - response_data << function_calls - blk.call(function_calls, cancel) - end - - if @xml_stripper - leftover = @xml_stripper.finish - orig_blk.call(leftover, cancel) if leftover.present? + if xml_tool_processor + xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) } end - + decode_chunk_finish.each { |partial| blk.call(partial, cancel) } return response_data ensure if log log.raw_response_payload = response_raw - log.response_tokens = tokenizer.size(partials_raw) final_log_update(log) + + log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank? log.save! if Rails.env.development? @@ -330,15 +228,15 @@ def prepare_request(_payload) raise NotImplementedError end - def extract_completion_from(_response_raw) + def decode(_response_raw) raise NotImplementedError end - def decode(chunk) - chunk + def decode_chunk_finish + [] end - def partials_from(_decoded_chunk) + def decode_chunk(_chunk) raise NotImplementedError end @@ -346,49 +244,73 @@ def extract_prompt_for_tokenizer(prompt) prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end - def build_buffer - Nokogiri::HTML5.fragment(<<~TEXT) - - #{noop_function_call_text} - - TEXT + def xml_tools_enabled? 
+ raise NotImplementedError end - def self.noop_function_call_text - (<<~TEXT).strip - - - - - - - TEXT - end + private - def noop_function_call_text - self.class.noop_function_call_text + def start_log( + provider_id:, + request_body:, + dialect:, + prompt:, + user:, + feature_name:, + feature_context: + ) + AiApiAuditLog.new( + provider_id: provider_id, + user_id: user&.id, + raw_request_payload: request_body, + request_tokens: prompt_size(prompt), + topic_id: dialect.prompt.topic_id, + post_id: dialect.prompt.post_id, + feature_name: feature_name, + language_model: llm_model.name, + feature_context: feature_context.present? ? feature_context.as_json : nil, + ) end - def has_tool?(response) - response.include?("") - end + def non_streaming_response( + response:, + xml_tool_processor:, + xml_stripper:, + partials_raw:, + response_raw: + ) + response_raw << response.read_body + response_data = decode(response_raw) - def chunk_to_string(chunk) - if chunk.is_a?(String) - chunk - else - chunk.to_s + response_data.each { |partial| partials_raw << partial.to_s } + + if xml_tool_processor + response_data.each do |partial| + processed = (xml_tool_processor << partial) + processed << xml_tool_processor.finish + response_data = [] + processed.flatten.compact.each { |inner| response_data << inner } + end end - end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if payload&.include?("") - matches = payload.match(%r{.*}m) - function_buffer = - Nokogiri::HTML5.fragment(matches[0] + "\n") if matches + if xml_stripper + response_data.map! do |partial| + stripped = (xml_stripper << partial) if partial.is_a?(String) + if stripped.present? + stripped + else + partial + end + end + response_data << xml_stripper.finish end - function_buffer + response_data.reject!(&:blank?) + + # this is to keep stuff backwards compatible + response_data = response_data.first if response_data.length == 1 + + response_data end end end diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index eaef21da2..bd3ae4eab 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -45,17 +45,21 @@ def perform_completion!( cancel_fn = lambda { cancelled = true } # We buffer and return tool invocations in one go. - if is_tool?(response) - yield(response, cancel_fn) - else - response.each_char do |char| - break if cancelled - yield(char, cancel_fn) + as_array = response.is_a?(Array) ? response : [response] + as_array.each do |response| + if is_tool?(response) + yield(response, cancel_fn) + else + response.each_char do |char| + break if cancelled + yield(char, cancel_fn) + end end end - else - response end + + response = response.first if response.is_a?(Array) && response.length == 1 + response end def tokenizer @@ -65,7 +69,7 @@ def tokenizer private def is_tool?(response) - Nokogiri::HTML5.fragment(response).at("function_calls").present? + response.is_a?(DiscourseAi::Completions::ToolCall) end end end diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb index 180c27c83..258062a1f 100644 --- a/lib/completions/endpoints/cohere.rb +++ b/lib/completions/endpoints/cohere.rb @@ -49,6 +49,47 @@ def prepare_request(payload) Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end + def decode(response_raw) + rval = [] + + parsed = JSON.parse(response_raw, symbolize_names: true) + + text = parsed[:text] + rval << parsed[:text] if !text.to_s.empty? 
# also allow " " + + # TODO tool calls + + update_usage(parsed) + + rval + end + + def decode_chunk(chunk) + @tool_idx ||= -1 + @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/) + (@json_decoder << chunk) + .map do |parsed| + update_usage(parsed) + rval = [] + + rval << parsed[:text] if !parsed[:text].to_s.empty? + + if tool_calls = parsed[:tool_calls] + tool_calls&.each do |tool_call| + @tool_idx += 1 + tool_name = tool_call[:name] + tool_params = tool_call[:parameters] + tool_id = "tool_#{@tool_idx}" + rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params) + end + end + + rval + end + .flatten + .compact + end + def extract_completion_from(response_raw) parsed = JSON.parse(response_raw, symbolize_names: true) @@ -77,36 +118,8 @@ def extract_completion_from(response_raw) end end - def has_tool?(_ignored) - @has_tool - end - - def native_tool_support? - true - end - - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if partial - tools = JSON.parse(partial) - tools.each do |tool| - name = tool["name"] - parameters = tool["parameters"] - xml_params = parameters.map { |k, v| "<#{k}>#{v}\n" }.join - - current_function = function_buffer.at("invoke") - if current_function.nil? || current_function.at("tool_name").content.present? - current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) - end - - current_function.at("tool_name").content = name == "search_local" ? "search" : name - current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(xml_params) - end - end - function_buffer + def xml_tools_enabled? + false end def final_log_update(log) @@ -114,10 +127,6 @@ def final_log_update(log) log.response_tokens = @output_tokens if @output_tokens end - def partials_from(decoded_chunk) - decoded_chunk.split("\n").compact - end - def extract_prompt_for_tokenizer(prompt) text = +"" if prompt[:chat_history] @@ -131,6 +140,18 @@ def extract_prompt_for_tokenizer(prompt) text end + + private + + def update_usage(parsed) + input_tokens = parsed.dig(:meta, :billed_units, :input_tokens) + input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens) + @input_tokens = input_tokens if input_tokens.present? + + output_tokens = parsed.dig(:meta, :billed_units, :output_tokens) + output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens) + @output_tokens = output_tokens if output_tokens.present? + end end end end diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index a51ff3ac2..15cc254d4 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -133,31 +133,35 @@ def perform_completion!( content = content.shift if content.is_a?(Array) if block_given? - split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort - indexes = [0, *split_indices, content.length] - - original_content = content - content = +"" - - cancel = false - cancel_proc = -> { cancel = true } - - i = 0 - indexes - .each_cons(2) - .map { |start, finish| original_content[start...finish] } - .each do |chunk| - break if cancel - if self.class.delays.present? 
&& - (delay = self.class.delays[i % self.class.delays.length]) - sleep(delay) - i += 1 + if content.is_a?(DiscourseAi::Completions::ToolCall) + yield(content, -> {}) + else + split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort + indexes = [0, *split_indices, content.length] + + original_content = content + content = +"" + + cancel = false + cancel_proc = -> { cancel = true } + + i = 0 + indexes + .each_cons(2) + .map { |start, finish| original_content[start...finish] } + .each do |chunk| + break if cancel + if self.class.delays.present? && + (delay = self.class.delays[i % self.class.delays.length]) + sleep(delay) + i += 1 + end + break if cancel + + content << chunk + yield(chunk, cancel_proc) end - break if cancel - - content << chunk - yield(chunk, cancel_proc) - end + end end content diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index ddf607b2d..2450dc99e 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -103,15 +103,7 @@ def extract_completion_from(response_raw) end end - def partials_from(decoded_chunk) - decoded_chunk - end - - def chunk_to_string(chunk) - chunk.to_s - end - - class Decoder + class GeminiStreamingDecoder def initialize @buffer = +"" end @@ -151,43 +143,87 @@ def decode(str) end def decode(chunk) - @decoder ||= Decoder.new - @decoder.decode(chunk) + json = JSON.parse(chunk, symbolize_names: true) + idx = -1 + json + .dig(:candidates, 0, :content, :parts) + .map do |part| + if part[:functionCall] + idx += 1 + ToolCall.new( + id: "tool_#{idx}", + name: part[:functionCall][:name], + parameters: part[:functionCall][:args], + ) + else + part = part[:text] + if part != "" + part + else + nil + end + end + end end - def extract_prompt_for_tokenizer(prompt) - prompt.to_s + def decode_chunk(chunk) + @tool_index ||= -1 + + streaming_decoder + .decode(chunk) + .map do |parsed| + update_usage(parsed) + parsed + .dig(:candidates, 0, :content, :parts) + .map do |part| + if part[:text] + part = part[:text] + if part != "" + part + else + nil + end + elsif part[:functionCall] + @tool_index += 1 + ToolCall.new( + id: "tool_#{@tool_index}", + name: part[:functionCall][:name], + parameters: part[:functionCall][:args], + ) + end + end + end + .flatten + .compact end - def has_tool?(_response_data) - @has_function_call + def update_usage(parsed) + usage = parsed.dig(:usageMetadata) + if usage + if prompt_token_count = usage[:promptTokenCount] + @prompt_token_count = prompt_token_count + end + if candidate_token_count = usage[:candidatesTokenCount] + @candidate_token_count = candidate_token_count + end + end end - def native_tool_support? - true + def final_log_update(log) + log.request_tokens = @prompt_token_count if @prompt_token_count + log.response_tokens = @candidate_token_count if @candidate_token_count end - def add_to_function_buffer(function_buffer, payload: nil, partial: nil) - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end - - function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? 
- - if partial[:args] - argument_fragments = - partial[:args].reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}" - end - argument_fragments << "\n" + def streaming_decoder + @decoder ||= GeminiStreamingDecoder.new + end - function_buffer.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - end + def extract_prompt_for_tokenizer(prompt) + prompt.to_s + end - function_buffer + def xml_tools_enabled? + false end end end diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index bd7edc063..b0b14722b 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -59,22 +59,30 @@ def prepare_request(payload) Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - # half a line sent here - return if !parsed - - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + def xml_tools_enabled? + true + end - response_h.dig(:content) + def decode(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true) + text = parsed.dig(:choices, 0, :message, :content) + if text.to_s.empty? + [""] + else + [text] + end end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data:", 2)[1] - data&.squish == "[DONE]" ? nil : data + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |parsed| + text = parsed.dig(:choices, 0, :delta, :content) + if text.to_s.empty? + nil + else + text + end end .compact end diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb index cc58006a3..dd4ca2c7e 100644 --- a/lib/completions/endpoints/ollama.rb +++ b/lib/completions/endpoints/ollama.rb @@ -37,12 +37,8 @@ def model_uri URI(llm_model.url) end - def native_tool_support? - @native_tool_support - end - - def has_tool?(_response_data) - @has_function_call + def xml_tools_enabled? + !@native_tool_support end def prepare_payload(prompt, model_params, dialect) @@ -67,74 +63,30 @@ def prepare_request(payload) Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def partials_from(decoded_chunk) - decoded_chunk.split("\n").compact + def decode_chunk(chunk) + # Native tool calls are not working right in streaming mode, use XML + @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/) + (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact end - def extract_completion_from(response_raw) + def decode(response_raw) + rval = [] parsed = JSON.parse(response_raw, symbolize_names: true) - return if !parsed - - response_h = parsed.dig(:message) - - @has_function_call ||= response_h.dig(:tool_calls).present? - @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content) - end - - def add_to_function_buffer(function_buffer, payload: nil, partial: nil) - @args_buffer ||= +"" - - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end - - f_name = partial.dig(:function, :name) - - @current_function ||= function_buffer.at("invoke") - - if f_name - current_name = function_buffer.at("tool_name").content - - if current_name.blank? 
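The Ollama endpoint gets the same treatment: non-streaming responses decode to text plus ToolCall objects, while streaming intentionally stays on the XML path per the comment in decode_chunk. An abbreviated sketch (the endpoint receiver is illustrative):

    body = {
      message: {
        content: "hi",
        tool_calls: [{ function: { name: "time", arguments: { timezone: "UTC" } } }],
      },
    }.to_json

    endpoint.decode(body)
    # => ["hi", ToolCall(id: "tool_0", name: "time", parameters: { timezone: "UTC" })]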
- # first call - else - # we have a previous function, so we need to add a noop - @args_buffer = +"" - @current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) - end - end - - @current_function.at("tool_name").content = f_name if f_name - @current_function.at("tool_id").content = partial[:id] if partial[:id] - - args = partial.dig(:function, :arguments) - - # allow for SPACE within arguments - if args && args != "" - @args_buffer << args.to_json - - begin - json_args = JSON.parse(@args_buffer, symbolize_names: true) - - argument_fragments = - json_args.reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{value}" - end - argument_fragments << "\n" - - @current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - rescue JSON::ParserError - return function_buffer + content = parsed.dig(:message, :content) + rval << content if !content.to_s.empty? + + idx = -1 + parsed + .dig(:message, :tool_calls) + &.each do |tool_call| + idx += 1 + id = "tool_#{idx}" + name = tool_call.dig(:function, :name) + args = tool_call.dig(:function, :arguments) + rval << ToolCall.new(id: id, name: name, parameters: args) end - end - function_buffer + rval end end end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 92315ed5c..a185a840a 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -93,98 +93,34 @@ def prepare_request(payload) end def final_log_update(log) - log.request_tokens = @prompt_tokens if @prompt_tokens - log.response_tokens = @completion_tokens if @completion_tokens + log.request_tokens = processor.prompt_tokens if processor.prompt_tokens + log.response_tokens = processor.completion_tokens if processor.completion_tokens end - def extract_completion_from(response_raw) - json = JSON.parse(response_raw, symbolize_names: true) - - if @streaming_mode - @prompt_tokens ||= json.dig(:usage, :prompt_tokens) - @completion_tokens ||= json.dig(:usage, :completion_tokens) - end - - parsed = json.dig(:choices, 0) - return if !parsed - - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) - @has_function_call ||= response_h.dig(:tool_calls).present? - @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content) + def decode(response_raw) + processor.process_message(JSON.parse(response_raw, symbolize_names: true)) end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data - end + def decode_chunk(chunk) + @decoder ||= JsonStreamDecoder.new + (@decoder << chunk) + .map { |parsed_json| processor.process_streamed_message(parsed_json) } + .flatten .compact end - def has_tool?(_response_data) - @has_function_call + def decode_chunk_finish + @processor.finish end - def native_tool_support? - true + def xml_tools_enabled? + false end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end - - @args_buffer ||= +"" - - f_name = partial.dig(:function, :name) - - @current_function ||= function_buffer.at("invoke") - - if f_name - current_name = function_buffer.at("tool_name").content - - if current_name.blank? 
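All of this hand-rolled buffer juggling moves into the new OpenAiMessageProcessor (added by this patch): streamed tool-call deltas accumulate and surface as one ToolCall, roughly:

    p = DiscourseAi::Completions::OpenAiMessageProcessor.new

    p.process_streamed_message(
      { choices: [{ delta: { tool_calls: [{ id: "call_1", function: { name: "time", arguments: "" } }] } }] },
    ) # => nil, a tool call opens

    p.process_streamed_message(
      { choices: [{ delta: { tool_calls: [{ function: { arguments: "{\"tz\":\"UTC\"}" } }] } }] },
    ) # => nil, arguments accumulate

    p.finish
    # => [ToolCall(id: "call_1", name: "time", parameters: { tz: "UTC" })]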
- # first call - else - # we have a previous function, so we need to add a noop - @args_buffer = +"" - @current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) - end - end - - @current_function.at("tool_name").content = f_name if f_name - @current_function.at("tool_id").content = partial[:id] if partial[:id] - - args = partial.dig(:function, :arguments) - - # allow for SPACE within arguments - if args && args != "" - @args_buffer << args - - begin - json_args = JSON.parse(@args_buffer, symbolize_names: true) - - argument_fragments = - json_args.reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}" - end - argument_fragments << "\n" - - @current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - rescue JSON::ParserError - return function_buffer - end - end + private - function_buffer + def processor + @processor ||= OpenAiMessageProcessor.new end end end diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb index ccb883cc8..cc81e786b 100644 --- a/lib/completions/endpoints/samba_nova.rb +++ b/lib/completions/endpoints/samba_nova.rb @@ -55,27 +55,31 @@ def final_log_update(log) log.response_tokens = @completion_tokens if @completion_tokens end - def extract_completion_from(response_raw) - json = JSON.parse(response_raw, symbolize_names: true) + def xml_tools_enabled? + true + end - if @streaming_mode - @prompt_tokens ||= json.dig(:usage, :prompt_tokens) - @completion_tokens ||= json.dig(:usage, :completion_tokens) - end + def decode(response_raw) + json = JSON.parse(response_raw, symbolize_names: true) + [json.dig(:choices, 0, :message, :content)] + end - parsed = json.dig(:choices, 0) - return if !parsed + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |json| + text = json.dig(:choices, 0, :delta, :content) - @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content) - end + @prompt_tokens ||= json.dig(:usage, :prompt_tokens) + @completion_tokens ||= json.dig(:usage, :completion_tokens) - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data + if !text.to_s.empty? + text + else + nil + end end + .flatten .compact end end diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 57fcf0518..6b371a09d 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -42,7 +42,10 @@ def model_uri def prepare_payload(prompt, model_params, dialect) payload = default_options.merge(model_params).merge(messages: prompt) - payload[:stream] = true if @streaming_mode + if @streaming_mode + payload[:stream] = true if @streaming_mode + payload[:stream_options] = { include_usage: true } + end payload end @@ -56,24 +59,42 @@ def prepare_request(payload) Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data - end - .compact + def xml_tools_enabled? 
+ true + end + + def final_log_update(log) + log.request_tokens = @prompt_tokens if @prompt_tokens + log.response_tokens = @completion_tokens if @completion_tokens end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - # half a line sent here - return if !parsed + def decode(response_raw) + json = JSON.parse(response_raw, symbolize_names: true) + @prompt_tokens = json.dig(:usage, :prompt_tokens) + @completion_tokens = json.dig(:usage, :completion_tokens) + [json.dig(:choices, 0, :message, :content)] + end + + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |parsed| + # vLLM keeps sending usage over and over again + prompt_tokens = parsed.dig(:usage, :prompt_tokens) + completion_tokens = parsed.dig(:usage, :completion_tokens) + + @prompt_tokens = prompt_tokens if prompt_tokens - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + @completion_tokens = completion_tokens if completion_tokens - response_h.dig(:content) + text = parsed.dig(:choices, 0, :delta, :content) + if text.to_s.empty? + nil + else + text + end + end + .compact end end end diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb deleted file mode 100644 index ef40809ca..000000000 --- a/lib/completions/function_call_normalizer.rb +++ /dev/null @@ -1,113 +0,0 @@ -# frozen_string_literal: true - -class DiscourseAi::Completions::FunctionCallNormalizer - attr_reader :done - - # blk is the block to call with filtered data - def initialize(blk, cancel) - @blk = blk - @cancel = cancel - @done = false - - @in_tool = false - - @buffer = +"" - @function_buffer = +"" - end - - def self.normalize(data) - text = +"" - cancel = -> {} - blk = ->(partial, _) { text << partial } - - normalizer = self.new(blk, cancel) - normalizer << data - - [text, normalizer.function_calls] - end - - def function_calls - return nil if @function_buffer.blank? 
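The deleted normalizer's job now belongs to the new XmlToolProcessor (added below), which turns streamed text into String fragments plus ToolCall objects. A usage sketch:

    processor = DiscourseAi::Completions::XmlToolProcessor.new
    result = []
    result.concat(processor << "hello ")
    result.concat(processor << "<function_calls><invoke><tool_name>search</tool_name>")
    result.concat(processor << "<parameters><query>sam</query></parameters></invoke></function_calls>")
    result.concat(processor.finish)
    # result => ["hello ", ToolCall(id: "tool_0", name: "search", parameters: { query: "sam" })]
    processor.should_cancel? # => true once </function_calls> is seen, so streaming can stop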
- - xml = Nokogiri::HTML5.fragment(@function_buffer) - self.class.normalize_function_ids!(xml) - last_invoke = xml.at("invoke:last") - if last_invoke - last_invoke.next_sibling.remove while last_invoke.next_sibling - xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling - end - xml.at("function_calls").to_s.dup.force_encoding("UTF-8") - end - - def <<(text) - @buffer << text - - if !@in_tool - # double check if we are clearly in a tool - search_length = text.length + 20 - search_string = @buffer[-search_length..-1] || @buffer - - index = search_string.rindex("") - @in_tool = !!index - if @in_tool - @function_buffer = @buffer[index..-1] - text_index = text.rindex("") - @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0 - end - else - @function_buffer << text - end - - if !@in_tool - if maybe_has_tool?(@buffer) - split_index = text.rindex("<").to_i - 1 - if split_index >= 0 - @function_buffer = text[split_index + 1..-1] || "" - text = text[0..split_index] || "" - else - @function_buffer << text - text = "" - end - else - if @function_buffer.length > 0 - @blk.call(@function_buffer, @cancel) - @function_buffer = +"" - end - end - - @blk.call(text, @cancel) if text.length > 0 - else - if text.include?("") - @done = true - @cancel.call - end - end - end - - def self.normalize_function_ids!(function_buffer) - function_buffer - .css("invoke") - .each_with_index do |invoke, index| - if invoke.at("tool_id") - invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank? - else - invoke.add_child("tool_#{index}\n") if !invoke.at("tool_id") - end - end - end - - private - - def maybe_has_tool?(text) - # 16 is the length of function calls - substring = text[-16..-1] || text - split = substring.split("<") - - if split.length > 1 - match = "<" + split.last - "".start_with?(match) - else - substring.ends_with?("<") - end - end -end diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb new file mode 100644 index 000000000..e575a3b78 --- /dev/null +++ b/lib/completions/json_stream_decoder.rb @@ -0,0 +1,48 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + # will work for anthropic and open ai compatible + class JsonStreamDecoder + attr_reader :buffer + + LINE_REGEX = /data: ({.*})\s*$/ + + def initialize(symbolize_keys: true, line_regex: LINE_REGEX) + @symbolize_keys = symbolize_keys + @buffer = +"" + @line_regex = line_regex + end + + def <<(raw) + @buffer << raw.to_s + rval = [] + + split = @buffer.scan(/.*\n?/) + split.pop if split.last.blank? + + @buffer = +(split.pop.to_s) + + split.each do |line| + matches = line.match(@line_regex) + next if !matches + rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys) + end + + if @buffer.present? 
+ matches = @buffer.match(@line_regex) + if matches + begin + rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys) + @buffer = +"" + rescue JSON::ParserError + # maybe it is a partial line + end + end + end + + rval + end + end + end +end diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb new file mode 100644 index 000000000..02369bec8 --- /dev/null +++ b/lib/completions/open_ai_message_processor.rb @@ -0,0 +1,103 @@ +# frozen_string_literal: true +module DiscourseAi::Completions + class OpenAiMessageProcessor + attr_reader :prompt_tokens, :completion_tokens + + def initialize + @tool = nil + @tool_arguments = +"" + @prompt_tokens = nil + @completion_tokens = nil + end + + def process_message(json) + result = [] + tool_calls = json.dig(:choices, 0, :message, :tool_calls) + + message = json.dig(:choices, 0, :message, :content) + result << message if message.present? + + if tool_calls.present? + tool_calls.each do |tool_call| + id = tool_call.dig(:id) + name = tool_call.dig(:function, :name) + arguments = tool_call.dig(:function, :arguments) + parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {} + result << ToolCall.new(id: id, name: name, parameters: parameters) + end + end + + update_usage(json) + + result + end + + def process_streamed_message(json) + rval = nil + + tool_calls = json.dig(:choices, 0, :delta, :tool_calls) + content = json.dig(:choices, 0, :delta, :content) + + finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == [] + + if tool_calls.present? + id = tool_calls.dig(0, :id) + name = tool_calls.dig(0, :function, :name) + arguments = tool_calls.dig(0, :function, :arguments) + + # TODO: multiple tool support may require index + #index = tool_calls[0].dig(:index) + + if id.present? && @tool && @tool.id != id + process_arguments + rval = @tool + @tool = nil + end + + if id.present? && name.present? + @tool_arguments = +"" + @tool = ToolCall.new(id: id, name: name) + end + + @tool_arguments << arguments.to_s + elsif finished_tools && @tool + parsed_args = JSON.parse(@tool_arguments, symbolize_names: true) + @tool.parameters = parsed_args + rval = @tool + @tool = nil + elsif content.present? + rval = content + end + + update_usage(json) + + rval + end + + def finish + rval = [] + if @tool + process_arguments + rval << @tool + @tool = nil + end + + rval + end + + private + + def process_arguments + if @tool_arguments.present? 
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        @tool_arguments = nil
+      end
+    end
+
+    def update_usage(json)
+      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+      @completion_tokens ||= json.dig(:usage, :completion_tokens)
+    end
+  end
+end
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 000000000..15be7b3f0
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        self.parameters = parameters if parameters
+        @parameters ||= {}
+      end
+
+      def parameters=(parameters)
+        raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
+        @parameters = parameters.symbolize_keys
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
diff --git a/lib/completions/xml_tool_processor.rb b/lib/completions/xml_tool_processor.rb
new file mode 100644
index 000000000..1b42b333c
--- /dev/null
+++ b/lib/completions/xml_tool_processor.rb
@@ -0,0 +1,124 @@
+# frozen_string_literal: true
+
+# This class can be used to process a stream of text that may contain XML tool
+# calls.
+# It will return either text or ToolCall objects.
+
+module DiscourseAi
+  module Completions
+    class XmlToolProcessor
+      def initialize
+        @buffer = +""
+        @function_buffer = +""
+        @should_cancel = false
+        @in_tool = false
+      end
+
+      def <<(text)
+        @buffer << text
+        result = []
+
+        if !@in_tool
+          # double check if we are clearly in a tool
+          search_length = text.length + 20
+          search_string = @buffer[-search_length..-1] || @buffer
+
+          index = search_string.rindex("<function_calls>")
+          @in_tool = !!index
+          if @in_tool
+            @function_buffer = @buffer[index..-1]
+            text_index = text.rindex("<function_calls>")
+            result << text[0..text_index - 1].strip if text_index && text_index > 0
+          end
+        else
+          @function_buffer << text
+        end
+
+        if !@in_tool
+          if maybe_has_tool?(@buffer)
+            split_index = text.rindex("<").to_i - 1
+            if split_index >= 0
+              @function_buffer = text[split_index + 1..-1] || ""
+              text = text[0..split_index] || ""
+            else
+              @function_buffer << text
+              text = ""
+            end
+          else
+            if @function_buffer.length > 0
+              result << @function_buffer
+              @function_buffer = +""
+            end
+          end
+
+          result << text if text.length > 0
+        else
+          @should_cancel = true if text.include?("</function_calls>")
+        end
+
+        result
+      end
+
+      def finish
+        return [] if @function_buffer.blank?
+
+        xml = Nokogiri::HTML5.fragment(@function_buffer)
+        normalize_function_ids!(xml)
+        last_invoke = xml.at("invoke:last")
+        if last_invoke
+          last_invoke.next_sibling.remove while last_invoke.next_sibling
+          xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+        end
+
+        xml
+          .css("invoke")
+          .map do |invoke|
+            tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
+            tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
+            parameters = {}
+            invoke
+              .at("parameters")
+              &.children
+              &.each do |node|
+                next if node.text?
+                name = node.name
+                value = node.content.to_s
+                parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
+              end
+            ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
+          end
+      end
+
+      def should_cancel?
+        @should_cancel
+      end
+
+      private
+
+      def normalize_function_ids!(function_buffer)
+        function_buffer
+          .css("invoke")
+          .each_with_index do |invoke, index|
+            if invoke.at("tool_id")
+              invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+            else
+              invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+            end
+          end
+      end
+
+      def maybe_has_tool?(text)
+        # 16 is the length of "<function_calls>"
+        substring = text[-16..-1] || text
+        split = substring.split("<")
+
+        if split.length > 1
+          match = "<" + split.last
+          "<function_calls>".start_with?(match)
+        else
+          substring.ends_with?("<")
+        end
+      end
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 94d5d6559..40eca30f6 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -104,7 +104,7 @@
 data: {"type":"message_stop"}
 STRING

-    result = +""
+    result = []

    body = body.scan(/.*\n/)
    EndpointMock.with_chunk_array_support do
      stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +113,17 @@
      end
    end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <search_query>s<a>m sam</search_query>
-      <category>general</category>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
-
-    expect(result.strip).to eq(expected)
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )
+
+    expect(result).to eq([tool_call])
  end

  it "can stream a response" do
@@ -191,6 +190,8 @@
    expect(log.feature_name).to eq("testing")
    expect(log.response_tokens).to eq(15)
    expect(log.request_tokens).to eq(25)
+    expect(log.raw_request_payload).to eq(expected_body.to_json)
+    expect(log.raw_response_payload.strip).to eq(body.strip)
  end

  it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@

    result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <expression>2758975 + 21.11</expression>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
-
-    expect(result.strip).to eq(expected)
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )
+
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
  end

  it "can send images via a completion prompt" do
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index d95193447..2a9cc77fa 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -79,7 +79,7 @@ def encode_message(message)
      }
      prompt.tools = [tool]

-      response = +""
+      response = []
      proxy.generate(prompt, user: user) { |partial| response << partial }

      expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ def encode_message(message)
      expect(parsed_body["tools"]).to eq(nil)
      expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

-      # note we now have a tool_id cause we were normalized
-      function_call = <<~XML.strip
-        hello
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <parameters>
-        <query>sydney weather today</query>
-        </parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
-      XML
+      expected = [
+        "hello\n",
+        DiscourseAi::Completions::ToolCall.new(
+          id: "tool_0",
+          name: "google",
+          parameters: {
+            query: "sydney weather today",
+          },
+        ),
+      ]

-      expect(response.strip).to eq(function_call)
+      expect(response).to eq(expected)
    end
  end

@@ -230,23 +227,23 @@ def encode_message(message)
      }
      prompt.tools = [tool]

-      response = +""
+      response = []
      proxy.generate(prompt, user: user) { |partial| response << partial }

      expect(request.headers["Authorization"]).to be_present
      expect(request.headers["X-Amz-Content-Sha256"]).to be_present

-      expected_response = (<<~RESPONSE).strip
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <query>sydney weather today</query>
-        <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
-        </invoke>
-        </function_calls>
-      RESPONSE
+      expected_response = [
+        DiscourseAi::Completions::ToolCall.new(
+          id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+          name: "google",
+          parameters: {
+            query: "sydney weather today",
+          },
+        ),
+      ]

-      expect(response.strip).to eq(expected_response)
+      expect(response).to eq(expected_response)

      expected = {
        "max_tokens" => 3000,
diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb
index 4bb213ffe..bdff8fc3c 100644
--- a/spec/lib/completions/endpoints/cohere_spec.rb
+++ b/spec/lib/completions/endpoints/cohere_spec.rb
@@ -66,7 +66,7 @@
 TEXT

    parsed_body = nil
-    result = +""
+    result = []

    sig = {
      name: "google",
@@ -91,21 +91,20 @@
      },
    ).to_return(status: 200, body: body.split("|"))

-      result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
+      llm.generate(prompt, user: user) { |partial, cancel| result << partial }
    end

-    expected = <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>google</tool_name>
-      <parameters><query>who is sam saffron</query>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    text = "I will search for 'who is sam saffron' and relay the information to the user."
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "google",
+        parameters: {
+          query: "who is sam saffron",
+        },
+      )

-    expect(result.strip).to eq(expected.strip)
+    expect(result).to eq([text, tool_call])

    expected = {
      model: "command-r-plus",
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 372c529b2..130c735b6 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -62,18 +62,14 @@ def tool_response
  end

  def invocation_response
-    <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>get_weather</tool_name>
-      <parameters>
-      <location>Sydney</location>
-      <unit>c</unit>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    DiscourseAi::Completions::ToolCall.new(
+      id: "tool_0",
+      name: "get_weather",
+      parameters: {
+        location: "Sydney",
+        unit: "c",
+      },
+    )
  end

  def tool_id
@@ -185,7 +181,7 @@ def regular_mode_tools(mock)
    mock.stub_tool_call(a_dialect.translate)

    completion_response = endpoint.perform_completion!(a_dialect, user)
-    expect(completion_response.strip).to eq(mock.invocation_response.strip)
+    expect(completion_response).to eq(mock.invocation_response)
  end

  def streaming_mode_simple_prompt(mock)
@@ -205,6 +201,7 @@
    expect(log.raw_request_payload).to be_present
    expect(log.raw_response_payload).to be_present
    expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+
    expect(log.response_tokens).to eq(
      endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
    )
@@ -216,14 +213,14 @@ def streaming_mode_tools(mock)
    a_dialect = dialect(prompt: prompt)

    mock.stub_streamed_tool_call(a_dialect.translate) do
-      buffered_partial = +""
+      buffered_partial = []

      endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
        buffered_partial << partial
-        cancel.call if buffered_partial.include?("</function_calls>")
+        cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
      end

-      expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
+      expect(buffered_partial).to eq([mock.invocation_response])
    end
  end

diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index 2f602d3a4..189338438 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -195,19 +195,16 @@ def tool_response

    response = llm.generate(prompt, user: user)

-    expected = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text><S>ydney</text>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    expect(response.strip).to eq(expected)
+    tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "<S>ydney",
+        },
+      )
+
+    expect(response).to eq(tool)
  end

  it "Supports Vision API" do
@@ -265,6 +262,68 @@ def tool_response
    expect(JSON.parse(req_body)).to eq(expected_prompt)
  end

+  it "Can stream tool calls correctly" do
+    rows = [
+      {
+        candidates: [
+          {
+            content: {
+              parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
+              role: "model",
+            },
+            safetyRatings: [
+              { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            ],
+          },
+        ],
+        usageMetadata: {
+          promptTokenCount: 625,
+          totalTokenCount: 625,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+      {
+        candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
+        usageMetadata: {
+          promptTokenCount: 625,
+          candidatesTokenCount: 4,
+          totalTokenCount: 629,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+    ]
+
+    payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
+
+    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+    url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+    prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
+
+    output = []
+
+    stub_request(:post, url).to_return(status: 200, body: payload)
+    llm.generate(prompt, user: user) { |partial| output << partial }
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "sam<>wh!s",
+        },
+      )
+
+    expect(output).to eq([tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(625)
+    expect(log.response_tokens).to eq(4)
+  end
+
  it "Can correctly handle streamed responses even if they are chunked badly" do
    data = +""
    data << "da|ta: |"
@@ -279,12 +338,12 @@
    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
    url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

-    output = +""
+    output = []

    gemini_mock.with_chunk_array_support do
      stub_request(:post, url).to_return(status: 200, body: split)
      llm.generate("Hello", user: user) { |partial| output << partial }
    end

-    expect(output).to eq("Hello World Sam")
+    expect(output.join).to eq("Hello World Sam")
  end
end
diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb
index eb6bc63c5..4f458283e 100644
--- a/spec/lib/completions/endpoints/ollama_spec.rb
+++ b/spec/lib/completions/endpoints/ollama_spec.rb
@@ -150,7 +150,7 @@ def request_body(prompt, tool_call: false)
  end

  describe "when using streaming mode" do
-    context "with simpel prompts" do
+    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 60df1d675..c4d7758a6 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -17,8 +17,8 @@ def response(content, tool_call: false)
      created: 1_678_464_820,
      model: "gpt-3.5-turbo-0301",
      usage: {
-        prompt_tokens: 337,
-        completion_tokens: 162,
+        prompt_tokens: 8,
+        completion_tokens: 13,
        total_tokens: 499,
      },
      choices: [
@@ -231,19 +231,16 @@ def request_body(prompt, stream: false, tool_call: false)

      result = llm.generate(prompt, user: user)

-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>hello</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "hello",
+          },
+        )

-      expect(result.strip).to eq(expected)
+      expect(result).to eq(tool_call)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@@ -320,19 +317,20 @@ def request_body(prompt, stream: false, tool_call: false)

      expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })

-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>h<e>llo</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      log = AiApiAuditLog.order(:id).last
+      expect(log.request_tokens).to eq(55)
+      expect(log.response_tokens).to eq(13)
+
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "h<e>llo",
+          },
+        )

-      expect(result.strip).to eq(expected)
+      expect(result).to eq(expected)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@@ -487,7 +485,7 @@ def request_body(prompt, stream: false, tool_call: false)

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}

-      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
+      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

@@ -495,32 +493,30 @@ def request_body(prompt, stream: false, tool_call: false)
    TEXT

      open_ai_mock.stub_raw(raw_data)
-      content = +""
+      response = []

      dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
-      endpoint.perform_completion!(dialect, user) { |partial| content << partial }
-
-      expected = <<~TEXT
-        <function_calls>
-        <invoke>
-        <tool_name>search</tool_name>
-        <parameters>
-        <search_query>Discourse AI bot</search_query>
-        </parameters>
-        <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
-        </invoke>
-        <invoke>
-        <tool_name>search</tool_name>
-        <parameters>
-        <query>Discourse AI bot</query>
-        </parameters>
-        <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
-        </invoke>
-        </function_calls>
-      TEXT
+      endpoint.perform_completion!(dialect, user) { |partial| response << partial }
+
+      tool_calls = [
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
+          parameters: {
+            search_query: "Discourse AI bot",
+          },
+        ),
+        DiscourseAi::Completions::ToolCall.new(
+          name: "search",
+          id: "call_H7YkbgYurHpyJqzwUN4bghwN",
+          parameters: {
+            query: "Discourse AI bot2",
+          },
+        ),
+      ]

-      expect(content).to eq(expected)
+      expect(response).to eq(tool_calls)
    end

    it "uses proper token accounting" do
@@ -593,21 +589,16 @@ def request_body(prompt, stream: false, tool_call: false)
      dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
      endpoint.perform_completion!(dialect, user) { |partial| partials << partial }

-      expect(partials.length).to eq(1)
-
-      function_call = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <parameters>
-        <query>Adabas 9.1</query>
-        </parameters>
-        <tool_id>func_id</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
-
-      expect(partials[0].strip).to eq(function_call)
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "func_id",
+          name: "google",
+          parameters: {
+            query: "Adabas 9.1",
+          },
+        )
+
+      expect(partials).to eq([tool_call])
    end
  end
end
diff --git a/spec/lib/completions/endpoints/samba_nova_spec.rb b/spec/lib/completions/endpoints/samba_nova_spec.rb
index 0f1f68acc..83839bf47 100644
--- a/spec/lib/completions/endpoints/samba_nova_spec.rb
+++ b/spec/lib/completions/endpoints/samba_nova_spec.rb
@@ -22,10 +22,15 @@
      },
    ).to_return(status: 200, body: body, headers: {})

-    response = +""
+    response = []
    llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }

-    expect(response).to eq("I am a bot")
+    expect(response).to eq(["I am a bot"])
+
+    log = AiApiAuditLog.order(:id).last
+
+    expect(log.request_tokens).to eq(21)
+    expect(log.response_tokens).to eq(41)
  end

  it "can perform regular completions" do
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 6f5387c0d..824bcbe06 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -51,7 +51,13 @@ def stub_streamed_response(prompt, deltas, tool_call: false)
      WebMock
        .stub_request(:post, "https://test.dev/v1/chat/completions")
-        .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
+        .with(
+          body:
+            model
+              .default_options
+              .merge(messages: prompt, stream: true, stream_options: { include_usage: true })
+              .to_json,
+        )
        .to_return(status: 200, body: chunks)
    end
  end
@@ -136,29 +142,115 @@ def stub_streamed_response(prompt, deltas, tool_call: false)

      result = llm.generate(prompt, user: Discourse.system_user)

-      expected = <<~TEXT
-        <function_calls>
-        <invoke>
-        <tool_name>calculate</tool_name>
-        <parameters>
-        <expression>1+1</expression>
-        </parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
-      TEXT
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "calculate",
+          id: "tool_0",
+          parameters: {
+            expression: "1+1",
+          },
+        )

-      expect(result.strip).to eq(expected.strip)
+      expect(result).to eq(expected)
    end
  end

+  it "correctly accounts for tokens in non streaming mode" do
+    body = (<<~TEXT).strip
+      {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
+    TEXT
+
+    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
+
+    result = llm.generate("generate a title", user: Discourse.system_user)
+
+    expect(result).to
eq("Random Number Generator Produces Smallest Possible Result") + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(146) + expect(log.response_tokens).to eq(10) + end + + it "can properly include usage in streaming mode" do + payload = <<~TEXT.strip + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: [DONE] + TEXT + + stub_request(:post, "https://test.dev/v1/chat/completions").to_return( + status: 200, + body: payload, + ) + + response = [] + llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial } + + expect(response.join).to eq( + "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?", + ) + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(46) + expect(log.response_tokens).to eq(26) + end + describe "#perform_completion!" 
do context "when using regular mode" do - context "with simple prompts" do - it "completes a trivial prompt and logs the response" do - compliance.regular_mode_simple_prompt(vllm_mock) - end - end - context "with tools" do it "returns a function invocation" do compliance.regular_mode_tools(vllm_mock) diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb deleted file mode 100644 index dd78ed7f2..000000000 --- a/spec/lib/completions/function_call_normalizer_spec.rb +++ /dev/null @@ -1,182 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do - let(:buffer) { +"" } - - let(:normalizer) do - blk = ->(data, cancel) { buffer << data } - cancel = -> { @done = true } - DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel) - end - - def pass_through!(data) - normalizer << data - expect(buffer[-data.length..-1]).to eq(data) - end - - it "is usable in non streaming mode" do - xml = (<<~XML).strip - hello - - - hello - - XML - - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expect(text).to eq("hello") - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "strips junk from end of function calls" do - xml = (<<~XML).strip - hello - - - hello - - junk - XML - - _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "returns nil for function calls if there are none" do - input = "hello world\n" - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input) - - expect(text).to eq(input) - expect(function_calls).to eq(nil) - end - - it "passes through data if there are no function calls detected" do - pass_through!("hello") - pass_through!("hello") - pass_through!("world") - pass_through!("") - end - - it "properly handles non English tools" do - normalizer << "hello\n" - - normalizer << (<<~XML).strip - - hello - - 世界 - - - XML - - expected = (<<~XML).strip - - - hello - - 世界 - - tool_0 - - - XML - - function_calls = normalizer.function_calls - expect(function_calls).to eq(expected) - end - - it "works correctly even if you only give it 1 letter at a time" do - xml = (<<~XML).strip - abc - - - hello - - world - - abc - - - hello2 - - world - - aba - - - XML - - xml.each_char { |char| normalizer << char } - - expect(buffer + normalizer.function_calls).to eq(xml) - end - - it "supports multiple invokes" do - xml = (<<~XML).strip - - - hello - - world - - abc - - - hello2 - - world - - aba - - - XML - - normalizer << xml - - expect(normalizer.function_calls).to eq(xml) - end - - it "can will cancel if it encounteres " do - normalizer << "" - expect(normalizer.done).to eq(false) - normalizer << "" - expect(normalizer.done).to eq(true) - expect(@done).to eq(true) - - expect(normalizer.function_calls).to eq("") - end - - it "pauses on function call and starts buffering" do - normalizer << "hello" - expect(buffer).to eq("hello") - expect(normalizer.done).to eq(false) - end -end diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb new file mode 100644 index 000000000..831bad6f3 --- /dev/null +++ b/spec/lib/completions/json_stream_decoder_spec.rb @@ -0,0 +1,47 @@ +# 
frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+  it "should be able to parse simple messages" do
+    result = decoder << "data: #{{ hello: "world" }.to_json}"
+    expect(result).to eq([{ hello: "world" }])
+  end
+
+  it "should handle anthropic mixed style streams" do
+    stream = (<<~TEXT).split("|")
+      event: |message_start|
+      data: |{"hel|lo": "world"}|
+
+      event: |message_start
+      data: {"foo": "bar"}
+
+      event: |message_start
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each { |chunk| results << (decoder << chunk) }
+
+    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+  end
+
+  it "should be able to handle complex overlaps" do
+    stream = (<<~TEXT).split("|")
+      data: |{"hel|lo": "world"}
+
+      data: {"foo": "bar"}
+
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each { |chunk| results << (decoder << chunk) }
+
+    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+  end
+end
diff --git a/spec/lib/completions/xml_tool_processor_spec.rb b/spec/lib/completions/xml_tool_processor_spec.rb
new file mode 100644
index 000000000..003f4356c
--- /dev/null
+++ b/spec/lib/completions/xml_tool_processor_spec.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
+  let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }
+
+  it "can process simple text" do
+    result = []
+    result << (processor << "hello")
+    result << (processor << " world ")
+    expect(result).to eq([["hello"], [" world "]])
+    expect(processor.finish).to eq([])
+    expect(processor.should_cancel?).to eq(false)
+  end
+
+  it "is usable for simple single message mode" do
+    xml = (<<~XML).strip
+      hello
+      <function_calls>
+      <invoke>
+      <tool_name>hello</tool_name>
+      <parameters>
+      <hello>world</hello>
+      <test>value</test>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "hello",
+        parameters: {
+          hello: "world",
+          test: "value",
+        },
+      )
+    expect(result).to eq([["hello"], [tool_call]])
+    expect(processor.should_cancel?).to eq(false)
+  end
+
+  it "handles multiple tool calls in sequence" do
+    xml = (<<~XML).strip
+      start
+      <function_calls>
+      <invoke>
+      <tool_name>first_tool</tool_name>
+      <parameters>
+      <param1>value1</param1>
+      </parameters>
+      </invoke>
+      <invoke>
+      <tool_name>second_tool</tool_name>
+      <parameters>
+      <param2>value2</param2>
+      </parameters>
+      </invoke>
+      </function_calls>
+      end
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    first_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "first_tool",
+        parameters: {
+          param1: "value1",
+        },
+      )
+
+    second_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_1",
+        name: "second_tool",
+        parameters: {
+          param2: "value2",
+        },
+      )
+
+    expect(result).to eq([["start"], [first_tool, second_tool]])
+    expect(processor.should_cancel?).to eq(true)
+  end
+
+  it "handles non-English parameters correctly" do
+    xml = (<<~XML).strip
+      こんにちは
+      <function_calls>
+      <invoke>
+      <tool_name>translator</tool_name>
+      <parameters>
+      <text>世界</text>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "translator",
+        parameters: {
+          text: "世界",
+        },
+      )
+
+    expect(result).to eq([["こんにちは"], [tool_call]])
+  end
+
+  it "processes input character by character" do
+    xml =
+      "hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke></function_calls>"
+
+    result = []
+    xml.each_char { |char| result << (processor << char) }
+    result << processor.finish
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
+
+    filtered_result = result.reject(&:empty?)
+    expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
+  end
+
+  it "handles malformed XML gracefully" do
+    xml = (<<~XML).strip
+      text
+      <function_calls>
+      <invoke>
+      <tool_name>test</tool_name>
+      <parameters>
+      <param>value
+      </invoke>
+      malformed
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    # Should just do its best to parse the XML
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
+    expect(result).to eq([["text"], [tool_call]])
+  end
+
+  it "correctly processes empty parameter sets" do
+    xml = (<<~XML).strip
+      hello
+      <function_calls>
+      <invoke>
+      <tool_name>no_params</tool_name>
+      <parameters>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
+
+    expect(result).to eq([["hello"], [tool_call]])
+  end
+
+  it "properly handles cancelled processing" do
+    xml = "start<function_calls></function_calls>"
+    result = []
+    result << (processor << xml)
+    result << (processor << "more text")
+    result << processor.finish
+
+    expect(result).to eq([["start"], [], []])
+    expect(processor.should_cancel?).to eq(true)
+  end
+end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 2fb95d199..5271374d8 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -72,40 +72,27 @@ def system_prompt
  it "can parse string that are wrapped in quotes" do
    SiteSetting.ai_stability_api_key = "123"

-    xml = <<~XML
-      <function_calls>
-      <invoke>
-      <tool_name>image</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <prompts>["cat oil painting", "big car"]</prompts>
-      <aspect_ratio>"16:9"</aspect_ratio>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>image</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <prompts>["cat oil painting", "big car"]</prompts>
-      <aspect_ratio>'16:9'</aspect_ratio>
-      </parameters>
-      </invoke>
-      </function_calls>
-    XML
-    image1, image2 =
-      tools =
-        DiscourseAi::AiBot::Personas::Artist.new.find_tools(
-          xml,
-          bot_user: nil,
-          llm: nil,
-          context: nil,
-        )
-
-    expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-    expect(image1.parameters[:aspect_ratio]).to eq("16:9")
-    expect(image2.parameters[:aspect_ratio]).to eq("16:9")
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "image",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          prompts: ["cat oil painting", "big car"],
+          aspect_ratio: "16:9",
+        },
+      )

-    expect(tools.length).to eq(2)
+    tool_instance =
+      DiscourseAi::AiBot::Personas::Artist.new.find_tool(
+        tool_call,
+        bot_user: nil,
+        llm: nil,
+        context: nil,
+      )
+
+    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
+    expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
  end

  it "enforces enums" do
@@ -132,42 +119,68 @@ def system_prompt
    XML

-    search1, search2 =
-      tools =
-        DiscourseAi::AiBot::Personas::General.new.find_tools(
-          xml,
-          bot_user: nil,
-          llm: nil,
-          context: nil,
-        )
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          max_posts: "3.2",
+          status: "cow",
+          foo: "bar",
+        },
+      )
+
+    tool_instance =
+      DiscourseAi::AiBot::Personas::General.new.find_tool(
+        tool_call,
+        bot_user: nil,
+        llm: nil,
+        context: nil,
+      )
+
+    expect(tool_instance.parameters.key?(:status)).to eq(false)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
max_posts: "3.2", + status: "open", + foo: "bar", + }, + ) - expect(search1.parameters.key?(:status)).to eq(false) - expect(search2.parameters[:status]).to eq("open") + tool_instance = + DiscourseAi::AiBot::Personas::General.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) + + expect(tool_instance.parameters[:status]).to eq("open") end it "can coerce integers" do - xml = <<~XML - - - search - call_JtYQMful5QKqw97XFsHzPweB - - "3.2" - hello world - bar - - - - XML + tool_call = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "call_JtYQMful5QKqw97XFsHzPweB", + parameters: { + max_posts: "3.2", + search_query: "hello world", + foo: "bar", + }, + ) - search, = - tools = - DiscourseAi::AiBot::Personas::General.new.find_tools( - xml, - bot_user: nil, - llm: nil, - context: nil, - ) + search = + DiscourseAi::AiBot::Personas::General.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) expect(search.parameters[:max_posts]).to eq(3) expect(search.parameters[:search_query]).to eq("hello world") @@ -177,43 +190,23 @@ def system_prompt it "can correctly parse arrays in tools" do SiteSetting.ai_openai_api_key = "123" - # Dall E tool uses an array for params - xml = <<~XML - - - dall_e - call_JtYQMful5QKqw97XFsHzPweB - - ["cat oil painting", "big car"] - - - - dall_e - abc - - ["pic3"] - - - - unknown - abc - - ["pic3"] - - - - XML - dall_e1, dall_e2 = - tools = - DiscourseAi::AiBot::Personas::DallE3.new.find_tools( - xml, - bot_user: nil, - llm: nil, - context: nil, - ) - expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) - expect(dall_e2.parameters[:prompts]).to eq(["pic3"]) - expect(tools.length).to eq(2) + tool_call = + DiscourseAi::Completions::ToolCall.new( + name: "dall_e", + id: "call_JtYQMful5QKqw97XFsHzPweB", + parameters: { + prompts: ["cat oil painting", "big car"], + }, + ) + + tool_instance = + DiscourseAi::AiBot::Personas::DallE3.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) + expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"]) end describe "custom personas" do diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 9c98a08a7..2a07ad524 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -55,6 +55,8 @@ ) end + before { SiteSetting.ai_embeddings_enabled = false } + after do # we must reset cache on persona cause data can be rolled back AiPersona.persona_cache.flush! 
@@ -83,17 +85,15 @@ end let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } - let(:function_call) { (<<~XML).strip } - - - search - 666 - - Can you use the custom tool - - - ", - XML + let(:tool_call) do + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "666", + parameters: { + query: "Can you use the custom tool", + }, + ) + end let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } @@ -115,7 +115,7 @@ reply_post = nil prompts = nil - responses = [function_call] + responses = [tool_call] DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| new_post = Fabricate(:post, raw: "Can you use the custom tool?") reply_post = playground.reply_to(new_post) @@ -133,7 +133,7 @@ it "can force usage of a tool" do tool_name = "custom-#{custom_tool.id}" ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1) - responses = [function_call, "custom tool did stuff (maybe)"] + responses = [tool_call, "custom tool did stuff (maybe)"] prompts = nil reply_post = nil @@ -166,7 +166,7 @@ bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) playground = DiscourseAi::AiBot::Playground.new(bot) - responses = [function_call, "custom tool did stuff (maybe)"] + responses = [tool_call, "custom tool did stuff (maybe)"] reply_post = nil @@ -206,13 +206,15 @@ bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) playground = DiscourseAi::AiBot::Playground.new(bot) + responses = ["custom tool did stuff (maybe)", tool_call] + # lets ensure tool does not run... DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| new_post = Fabricate(:post, raw: "Can you use the custom tool?") reply_post = playground.reply_to(new_post) end - expect(reply_post.raw.strip).to eq(function_call) + expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)") end end @@ -452,10 +454,25 @@ it "can run tools" do persona.update!(tools: ["Time"]) - responses = [ - "timetimeBuenos Aires", - "The time is 2023-12-14 17:24:00 -0300", - ] + tool_call1 = + DiscourseAi::Completions::ToolCall.new( + name: "time", + id: "time", + parameters: { + timezone: "Buenos Aires", + }, + ) + + tool_call2 = + DiscourseAi::Completions::ToolCall.new( + name: "time", + id: "time", + parameters: { + timezone: "Sydney", + }, + ) + + responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"] message = DiscourseAi::Completions::Llm.with_prepared_responses(responses) do @@ -470,7 +487,8 @@ # it also needs to have tool details now set on message prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id) - expect(prompt.custom_prompt.length).to eq(3) + + expect(prompt.custom_prompt.length).to eq(5) # TODO in chat I am mixed on including this in the context, but I guess maybe? 
# thinking about this @@ -782,30 +800,29 @@ end it "supports multiple function calls" do - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - search - search - - another search - - - - TXT + tool_call1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) + + tool_call2 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "another search", + }, + ) response2 = "I found stuff" - DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do - playground.reply_to(third_post) - end + DiscourseAi::Completions::Llm.with_prepared_responses( + [[tool_call1, tool_call2], response2], + ) { playground.reply_to(third_post) } last_post = third_post.topic.reload.posts.order(:post_number).last @@ -819,17 +836,14 @@ bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new) playground = described_class.new(bot) - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - TXT + response1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) response2 = "I found stuff" @@ -843,17 +857,14 @@ end it "does not include placeholders in conversation context but includes all completions" do - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - TXT + response1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) response2 = "I found some really amazing stuff!" @@ -889,17 +900,15 @@ [{ b64_json: image, revised_prompt: "a pink cow 1" }] end - let(:response) { (<<~TXT).strip } - - - dall_e - dall_e - - ["a pink cow"] - - - - TXT + let(:response) do + DiscourseAi::Completions::ToolCall.new( + name: "dall_e", + id: "dall_e", + parameters: { + prompts: ["a pink cow"], + }, + ) + end it "properly returns an image when skipping tool details" do persona.update!(tool_details: false) diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index fb42506e0..16e0001bb 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -541,16 +541,10 @@ def validate_streamed_response(raw_http, expected) expect(topic.title).to eq("An amazing title") expect(topic.posts.count).to eq(2) - # now let's try to make a reply with a tool call - function_call = <<~XML - - - categories - - - XML - - fake_endpoint.fake_content = [function_call, "this is the response after the tool"] + tool_call = + DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1") + + fake_endpoint.fake_content = [tool_call, "this is the response after the tool"] # this simplifies function calls fake_endpoint.chunk_count = 1 diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb index e74301857..007e868be 100644 --- a/spec/requests/ai_bot/bot_controller_spec.rb +++ b/spec/requests/ai_bot/bot_controller_spec.rb @@ -4,6 +4,8 @@ fab!(:user) fab!(:pm_topic) { Fabricate(:private_message_topic) } fab!(:pm_post) { Fabricate(:post, topic: pm_topic) } + fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) } + fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) } before { sign_in(user) } @@ -22,15 +24,37 @@ user = 
pm_topic.topic_allowed_users.first.user sign_in(user) - AiApiAuditLog.create!( - post_id: pm_post.id, - provider_id: 1, - topic_id: pm_topic.id, - raw_request_payload: "request", - raw_response_payload: "response", - request_tokens: 1, - response_tokens: 2, - ) + log1 = + AiApiAuditLog.create!( + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) + + log2 = + AiApiAuditLog.create!( + post_id: pm_post.id, + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) + + log3 = + AiApiAuditLog.create!( + post_id: pm_post2.id, + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) Group.refresh_automatic_groups! SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s @@ -38,18 +62,26 @@ get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info" expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log2.id) + expect(response.parsed_body["next_log_id"]).to eq(log3.id) + expect(response.parsed_body["prev_log_id"]).to eq(log1.id) + expect(response.parsed_body["topic_id"]).to eq(pm_topic.id) + expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) expect(response.parsed_body["raw_request_payload"]).to eq("request") expect(response.parsed_body["raw_response_payload"]).to eq("response") - post2 = Fabricate(:post, topic: pm_topic) - # return previous post if current has no debug info - get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info" + get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info" expect(response.status).to eq(200) expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) + + # can return debug info by id as well + get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}" + expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log1.id) end end
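
A minimal sketch of the consumer-side contract this patch establishes: streamed
partials are now either plain String deltas or ToolCall objects, so callers no
longer parse XML themselves. The `llm`, `prompt`, and `user` variables here are
assumed to be set up as in the specs above; the partial/cancel block signature
is taken from endpoint_compliance.rb:

    llm.generate(prompt, user: user) do |partial, cancel|
      case partial
      when DiscourseAi::Completions::ToolCall
        # tool calls arrive fully parsed (id, name, symbolized parameters)
        puts "tool: #{partial.name} (#{partial.id}) #{partial.parameters.inspect}"
        # stop streaming once a tool call shows up, as the compliance spec does
        cancel.call
      when String
        print partial # plain text deltas stream through unchanged
      end
    end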