Skip to content

Commit

Permalink
FEATURE: improve tool support (#904)
Browse files Browse the repository at this point in the history
This re-implements tool support in DiscourseAi::Completions::Llm #generate

Previously, tool support was always returned via XML, and it was the responsibility of the caller to parse that XML.

New implementation has the endpoints return ToolCall objects.

Additionally this simplifies the Llm endpoint interface and gives it more clarity. Llms must implement

decode, decode_chunk (for streaming)

It is the implementer's responsibility to figure out how to decode chunks; the base class no longer implements this. To make this easy we ship a flexible JSON decoder which is easy to wire up.

Also (new)

    Better debugging for PMs: we now have next / previous buttons to see all the LLM messages associated with a PM
    Token accounting is fixed for vllm (we were not correctly counting tokens)
  • Loading branch information
SamSaffron authored Nov 11, 2024
1 parent 644141f commit e817b7d
Show file tree
Hide file tree
Showing 43 changed files with 1,685 additions and 1,293 deletions.
8 changes: 8 additions & 0 deletions app/controllers/discourse_ai/ai_bot/bot_controller.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ class BotController < ::ApplicationController
requires_plugin ::DiscourseAi::PLUGIN_NAME
requires_login

def show_debug_info_by_id
  # Fetches a single AI API audit log by id for the debug modal's
  # prev/next navigation. 404s unless the log is attached to a topic,
  # and requires debug permission on that topic.
  log = AiApiAuditLog.find(params[:id])
  topic = log.topic
  raise Discourse::NotFound unless topic

  guardian.ensure_can_debug_ai_bot_conversation!(topic)
  render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
end

def show_debug_info
post = Post.find(params[:post_id])
guardian.ensure_can_debug_ai_bot_conversation!(post)
Expand Down
8 changes: 8 additions & 0 deletions app/models/ai_api_audit_log.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ module Provider
Ollama = 7
SambaNova = 8
end

def next_log_id
  # Id of the next audit log on the same topic (debug modal "next" button).
  # pick(:id) adds LIMIT 1, instead of pluck(:id).first which loads every
  # matching id into memory before discarding all but the first.
  self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pick(:id)
end

def prev_log_id
  # Id of the previous audit log on the same topic (debug modal "previous"
  # button). pick(:id) adds LIMIT 1 rather than loading all matching ids.
  self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pick(:id)
end
end

# == Schema Information
Expand Down
4 changes: 3 additions & 1 deletion app/serializers/ai_api_audit_log_serializer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
:post_id,
:feature_name,
:language_model,
:created_at
:created_at,
:prev_log_id,
:next_log_id
end
51 changes: 46 additions & 5 deletions assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
import i18n from "discourse-common/helpers/i18n";
import discourseLater from "discourse-common/lib/later";
Expand Down Expand Up @@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
this.copy(this.info.raw_response_payload);
}

async loadLog(logId) {
try {
await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
(result) => {
this.info = result;
}
);
} catch (e) {
popupAjaxError(e);
}
}

@action
prevLog() {
  // Navigate the debug modal to the previous audit log for this topic.
  // Only reachable when the template's `prev_log_id` guard renders the button.
  this.loadLog(this.info.prev_log_id);
}

@action
nextLog() {
  // Navigate the debug modal to the next audit log for this topic.
  // Only reachable when the template's `next_log_id` guard renders the button.
  this.loadLog(this.info.next_log_id);
}

copy(text) {
clipboardCopy(text);
this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
Expand All @@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
}

loadApiRequestInfo() {
ajax(
`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
).then((result) => {
this.info = result;
});
ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
.then((result) => {
this.info = result;
})
.catch((e) => {
popupAjaxError(e);
});
}

get requestActive() {
Expand Down Expand Up @@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
@action={{this.copyResponse}}
@label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
/>
{{#if this.info.prev_log_id}}
<DButton
class="btn"
@icon="angles-left"
@action={{this.prevLog}}
@label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
/>
{{/if}}
{{#if this.info.next_log_id}}
<DButton
class="btn"
@icon="angles-right"
@action={{this.nextLog}}
@label="discourse_ai.ai_bot.debug_ai_modal.next_log"
/>
{{/if}}
<span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
</:footer>
</DModal>
Expand Down
2 changes: 2 additions & 0 deletions config/locales/client.en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@ en:
response_tokens: "Response tokens:"
request: "Request"
response: "Response"
next_log: "Next"
previous_log: "Previous"

share_full_topic_modal:
title: "Share Conversation Publicly"
Expand Down
1 change: 1 addition & 0 deletions config/routes.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
  get "bot-username" => "bot#show_bot_username"
  get "post/:post_id/show-debug-info" => "bot#show_debug_info"
  # Fetch a specific audit log by id — used by the debug modal's
  # prev/next navigation buttons.
  get "show-debug-info/:id" => "bot#show_debug_info_by_id"
  post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
end

Expand Down
19 changes: 12 additions & 7 deletions lib/ai_bot/bot.rb
Original file line number Diff line number Diff line change
Expand Up @@ -100,30 +100,35 @@ def reply(context, &update_blk)
llm_kwargs[:top_p] = persona.top_p if persona.top_p

needs_newlines = false
tools_ran = 0

while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
force_tool_if_needed(prompt, context)

result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
tool = nil if tools_ran >= MAX_TOOLS

if (tools.present?)
if tool.present?
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
update_blk.call("\n\n", cancel)
needs_newlines = false
end

tools[0..MAX_TOOLS].each do |tool|
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
ongoing_chain &&= tool.chain_next_response?
end
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
tools_ran += 1
ongoing_chain &&= tool.chain_next_response?
else
needs_newlines = true
update_blk.call(partial, cancel)
if partial.is_a?(DiscourseAi::Completions::ToolCall)
Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
else
update_blk.call(partial, cancel)
end
end
end

Expand Down
21 changes: 7 additions & 14 deletions lib/ai_bot/personas/persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,16 @@ def craft_prompt(context, llm: nil)
prompt
end

def find_tools(partial, bot_user:, llm:, context:)
return [] if !partial.include?("</invoke>")

parsed_function = Nokogiri::HTML5.fragment(partial)
parsed_function
.css("invoke")
.map do |fragment|
tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
end
.compact
def find_tool(partial, bot_user:, llm:, context:)
  # Only ToolCall partials from the LLM can be resolved into tool instances;
  # anything else (plain text chunks) yields nil.
  return unless partial.is_a?(DiscourseAi::Completions::ToolCall)

  tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
end

protected

def tool_instance(parsed_function, bot_user:, llm:, context:)
function_id = parsed_function.at("tool_id")&.text
function_name = parsed_function.at("tool_name")&.text
def tool_instance(tool_call, bot_user:, llm:, context:)
function_id = tool_call.id
function_name = tool_call.name
return nil if function_name.nil?

tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
Expand All @@ -224,7 +217,7 @@ def tool_instance(parsed_function, bot_user:, llm:, context:)
arguments = {}
tool_klass.signature[:parameters].to_a.each do |param|
name = param[:name]
value = parsed_function.at(name)&.text
value = tool_call.parameters[name.to_sym]

if param[:type] == "array" && value
value =
Expand Down
118 changes: 56 additions & 62 deletions lib/completions/anthropic_message_processor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,87 +13,81 @@ def initialize(name, id)
def append(json)
  # Accumulate a streamed JSON fragment onto this tool call's buffer.
  @raw_json.concat(json)
end

def to_tool_call
  # Parse the fully-accumulated JSON buffer into a ToolCall value object.
  args = JSON.parse(raw_json, symbolize_names: true)
  DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: args)
end
end

attr_reader :tool_calls, :input_tokens, :output_tokens

# streaming_mode: true when this processor will receive SSE chunks via
# process_streamed_message; false when handling a single complete payload.
def initialize(streaming_mode:)
  @streaming_mode = streaming_mode
  # Completed streaming tool calls, drained via to_tool_calls.
  @tool_calls = []
  # Tool call currently being buffered from partial_json deltas, if any.
  @current_tool_call = nil
end

def to_xml_tool_calls(function_buffer)
return function_buffer if @tool_calls.blank?

function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
</function_calls>
TEXT

@tool_calls.each do |tool_call|
node =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(
DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
),
)

params = JSON.parse(tool_call.raw_json, symbolize_names: true)
xml =
params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
def to_tool_calls
  # Convert every buffered streaming tool call into a ToolCall value object.
  # Symbol#to_proc replaces the redundant single-method block.
  @tool_calls.map(&:to_tool_call)
end

node.at("tool_name").content = tool_call.name
node.at("tool_id").content = tool_call.id
node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
# Processes one parsed SSE event from the Anthropic messages API.
# Returns a String text fragment, a ToolCall (when a tool-use block
# completes), or nil when the event only updates internal state
# (token accounting, partial tool-call buffering).
#
# NOTE: the source span contained a stray leftover line (`function_buffer`)
# from the removed XML-based implementation, which would raise NameError
# if reached; it has been dropped here.
def process_streamed_message(parsed)
  result = nil

  if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
    tool_name = parsed.dig(:content_block, :name)
    tool_id = parsed.dig(:content_block, :id)
    # Flush any tool call still in flight before starting a new one.
    result = @current_tool_call.to_tool_call if @current_tool_call
    @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
  elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
    if @current_tool_call
      # Tool arguments stream in as partial JSON; buffer until block stop.
      tool_delta = parsed.dig(:delta, :partial_json).to_s
      @current_tool_call.append(tool_delta)
    else
      result = parsed.dig(:delta, :text).to_s
    end
  elsif parsed[:type] == "content_block_stop"
    if @current_tool_call
      result = @current_tool_call.to_tool_call
      @current_tool_call = nil
    end
  elsif parsed[:type] == "message_start"
    @input_tokens = parsed.dig(:message, :usage, :input_tokens)
  elsif parsed[:type] == "message_delta"
    @output_tokens =
      parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
  elsif parsed[:type] == "message_stop"
    # Bedrock reports final token counts via its invocation metrics key.
    if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
      @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
      @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
    end
  end

  result
end

# Processes a complete (non-streamed) Anthropic message payload.
# payload may be a raw JSON String or an already-parsed Hash with symbol
# keys. Returns an Array mixing text Strings and ToolCall objects (or ""
# when the payload has no content array), and records token usage.
#
# NOTE: the source span interleaved the removed streaming branch with this
# implementation (diff-render artifact); streaming is now handled by
# process_streamed_message, so only the non-streaming path remains here.
def process_message(payload)
  result = ""
  parsed = payload
  parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)

  content = parsed.dig(:content)
  if content.is_a?(Array)
    result =
      content.map do |data|
        if data[:type] == "tool_use"
          call = AnthropicToolCall.new(data[:name], data[:id])
          call.append(data[:input].to_json)
          call.to_tool_call
        else
          data[:text]
        end
      end
  end

  @input_tokens = parsed.dig(:usage, :input_tokens)
  @output_tokens = parsed.dig(:usage, :output_tokens)

  result
end
end
19 changes: 17 additions & 2 deletions lib/completions/dialects/ollama.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,23 @@ def enable_native_tool?
# Builds an Ollama chat "user" message from a prompt message. Image
# uploads are attached as an array of base64 strings under :images,
# matching the Ollama chat API; non-image uploads are skipped.
#
# NOTE: the source span carried leftover removed TODO lines from the diff
# render; only the current implementation is kept here. filter_map
# replaces the map-then-compact pair.
def user_msg(msg)
  user_message = { role: "user", content: msg[:content] }

  encoded_uploads = prompt.encoded_uploads(msg)
  if encoded_uploads.present?
    images =
      encoded_uploads.filter_map do |upload|
        upload[:base64] if upload[:mime_type].start_with?("image/")
      end

    user_message[:images] = images if images.present?
  end

  # TODO: Add support for user messages with embedded user ids

  user_message
end
Expand Down
Loading

0 comments on commit e817b7d

Please sign in to comment.