Skip to content

Commit

Permalink
Anthropic implementation of partial streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
SamSaffron committed Nov 12, 2024
1 parent e817b7d commit 019c728
Show file tree
Hide file tree
Showing 10 changed files with 781 additions and 25 deletions.
80 changes: 76 additions & 4 deletions lib/completions/anthropic_message_processor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,92 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
class AnthropicToolCall
attr_reader :name, :raw_json, :id

def initialize(name, id)
# Sets up the accumulator for a single tool call streamed by the model.
#
# name               - tool name
# id                 - tool call id
# partial_tool_calls - when truthy, a ToolCallProgressTracker is attached so
#                      appended JSON fragments are also parsed incrementally
def initialize(name, id, partial_tool_calls: false)
  @id = id
  @name = name
  @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
  # Unary plus yields an unfrozen buffer that can be grown in place.
  @raw_json = +""

  if partial_tool_calls
    @streaming_parser = ToolCallProgressTracker.new(self)
  end
end

# Adds a streamed JSON fragment to the accumulated raw payload and, when an
# incremental parser is attached, feeds it the same fragment.
def append(json)
  @raw_json << json
  return unless @streaming_parser

  @streaming_parser << json
end

# Records an in-flight parameter value on the tool call being built.
#
# key   - parameter name; converted with #to_sym before it is stored
# value - the value seen so far for that parameter
#
# Marks the tool call as partial and sets the new-data flag.
def notify_progress(key, value)
  @tool_call.partial = true
  params = @tool_call.parameters
  params[key.to_sym] = value
  @has_new_data = true
end

# Returns the new-data flag: true once notify_progress has run and
# partial_tool_call has not been called since; false after partial_tool_call
# has reset it; nil if no progress was ever recorded.
def has_partial?
@has_new_data
end

# Returns the in-progress tool call object and clears the new-data flag, so
# has_partial? stays falsy until further progress is recorded.
def partial_tool_call
@has_new_data = false
@tool_call
end

# Finalizes the tool call once all JSON fragments have been appended.
#
# Parses the accumulated raw JSON with symbolized keys, stores the result on
# the tool call object that was updated during streaming, clears its partial
# flag and returns that same object. No second ToolCall is allocated: the
# instance built in the constructor is reused.
#
# Raises JSON::ParserError when the accumulated payload is not valid JSON.
def to_tool_call
  parameters = JSON.parse(raw_json, symbolize_names: true)
  @tool_call.partial = false
  @tool_call.parameters = parameters
  @tool_call
end
end

# Feeds streamed JSON fragments through a streaming JSON parser and reports
# parameter values to the owning tool call as they become available.
#
# The object handed to the constructor only needs to respond to
# notify_progress(key, value).
class ToolCallProgressTracker
  attr_reader :current_key, :current_value, :tool_call

  # tool_call - receiver of notify_progress(key, value) callbacks
  def initialize(tool_call)
    @tool_call = tool_call
    @current_key = nil
    @current_value = nil
    # Set once the parser rejects a fragment; later input is then ignored.
    @broken = false
    @parser = DiscourseAi::Completions::JsonStreamingParser.new

    # Remember which key the next reported value belongs to.
    # NOTE(review): presumably invoked by the parser for each object key; confirm
    # against JsonStreamingParser.
    @parser.key do |k|
      @current_key = k
      @current_value = nil
    end

    # Forward a value reported by the parser, but only while a key is pending.
    @parser.value do |v|
      tool_call.notify_progress(@current_key, v) if @current_key
    end
  end

  # Pushes another JSON fragment into the parser.
  #
  # The LLM could send broken JSON. In that case the tracker stops streaming
  # progress (every later fragment is ignored) and the problem is left to be
  # dealt with later, outside this class.
  #
  # Returns nil.
  def <<(json)
    return if @broken

    begin
      @parser << json
    rescue DiscourseAi::Utils::ParserError
      @broken = true
      return
    end

    # This is worth notifying: report the parser's current buffer for the
    # pending key.
    # NOTE(review): :start_string looks like "inside an unfinished string value"
    # - confirm against JsonStreamingParser.
    if @parser.state == :start_string && @current_key
      tool_call.notify_progress(@current_key, @parser.buf)
    end

    # Drop the pending key once the parser reports the :end_value state.
    @current_key = nil if @parser.state == :end_value
  end
end

attr_reader :tool_calls, :input_tokens, :output_tokens

def initialize(streaming_mode:)
# Builds a processor with an empty tool call list and no tool call in flight.
#
# streaming_mode     - stored in @streaming_mode as given
# partial_tool_calls - stored in @partial_tool_calls as given (default false)
def initialize(streaming_mode:, partial_tool_calls: false)
  @partial_tool_calls = partial_tool_calls
  @streaming_mode = streaming_mode
  @current_tool_call = nil
  @tool_calls = []
end

def to_tool_calls
Expand All @@ -38,11 +102,19 @@ def process_streamed_message(parsed)
tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id)
result = @current_tool_call.to_tool_call if @current_tool_call
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
@current_tool_call =
AnthropicToolCall.new(
tool_name,
tool_id,
partial_tool_calls: @partial_tool_calls,
) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @current_tool_call
tool_delta = parsed.dig(:delta, :partial_json).to_s
@current_tool_call.append(tool_delta)
if @current_tool_call.has_partial?
result = @current_tool_call.partial_tool_call
end
else
result = parsed.dig(:delta, :text).to_s
end
Expand Down
2 changes: 1 addition & 1 deletion lib/completions/endpoints/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def decode(response_data)

# Lazily builds and memoizes the Anthropic message processor, forwarding the
# current streaming mode and the partial_tool_calls setting. Every call
# returns the same memoized instance.
def processor
  @processor ||=
    DiscourseAi::Completions::AnthropicMessageProcessor.new(
      streaming_mode: @streaming_mode,
      partial_tool_calls: partial_tool_calls,
    )
end

def has_tool?(_response_data)
Expand Down
4 changes: 4 additions & 0 deletions lib/completions/endpoints/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module DiscourseAi
module Completions
module Endpoints
class Base
attr_reader :partial_tool_calls

CompletionFailed = Class.new(StandardError)
TIMEOUT = 60

Expand Down Expand Up @@ -58,8 +60,10 @@ def perform_completion!(
model_params = {},
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
&blk
)
@partial_tool_calls = partial_tool_calls
model_params = normalize_model_params(model_params)
orig_blk = blk

Expand Down
3 changes: 2 additions & 1 deletion lib/completions/endpoints/canned_response.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def perform_completion!(
_user,
_model_params,
feature_name: nil,
feature_context: nil
feature_context: nil,
partial_tool_calls: false
)
@dialect = dialect
response = responses[completions]
Expand Down
3 changes: 2 additions & 1 deletion lib/completions/endpoints/fake.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def perform_completion!(
user,
model_params = {},
feature_name: nil,
feature_context: nil
feature_context: nil,
partial_tool_calls: false
)
last_call = { dialect: dialect, user: user, model_params: model_params }
self.class.last_call = last_call
Expand Down
1 change: 1 addition & 0 deletions lib/completions/endpoints/open_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def perform_completion!(
model_params = {},
feature_name: nil,
feature_context: nil,
partial_tool_calls: false,
&blk
)
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
Expand Down
Loading

0 comments on commit 019c728

Please sign in to comment.