Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIChat refactor to support standalone and persistent conversations #24921

Merged
merged 32 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5545b1a
AI Chat: Introduce AIChatService, ConversationHandler, and direct bin…
petemill Jul 16, 2024
b7d2662
ConversationHandler doesn't need to deal with navigation ID
petemill Sep 12, 2024
976e6a2
test fix
petemill Sep 12, 2024
c1100a8
AIChatTabHelper params instead of multiple test. Always trim content.
petemill Sep 13, 2024
7ec5efa
test and review feedback - comments, id->uuid, page-navigation-tests
petemill Sep 13, 2024
657f79e
fix for android build
petemill Sep 13, 2024
101364c
fix same-document back/forward navigation by considering page title c…
petemill Sep 13, 2024
f906c8b
ios refactor
petemill Sep 17, 2024
51f1357
Fix compiling on iOS. Fix Service registration crash.
Brandon-T Sep 17, 2024
86abbe2
Fix crashes on iOS. Fix logic so AIChat on iOS works correctly. Fix m…
Brandon-T Sep 17, 2024
9d2e1e7
fix ConversationHandler::GenerateQuestions, refactor non-conversation…
petemill Sep 18, 2024
e70fe47
feedback
petemill Sep 18, 2024
474284f
fix AIChatRenderViewContextMenuBrowserTest
petemill Sep 18, 2024
a3e4367
don't wait for client connection before submitting human message
petemill Sep 18, 2024
e07ff52
AIChatService::MaybeAssociateContentWithConversation
petemill Sep 18, 2024
d173540
feedback
petemill Sep 18, 2024
313781e
android HandleVoiceRecognition now optionally passes ConversationId t…
petemill Sep 18, 2024
c505f43
feedback
petemill Sep 18, 2024
197b5cf
fix ModelService migrating from chat-claude-instant default model pre…
petemill Sep 18, 2024
a47c7ed
feedback
petemill Sep 19, 2024
288fa5d
format
petemill Sep 19, 2024
de3761e
rebase fixes
petemill Sep 19, 2024
67b9ff1
AIChatTabHelper refine and test retry logic
petemill Sep 20, 2024
498af7a
fix android compile
petemill Sep 20, 2024
19f155c
fix android again
petemill Sep 20, 2024
d156c64
no channel_info new string
petemill Sep 20, 2024
c40131f
associatedcontentdriver - remove is_page_text_fetch_in_progress_
petemill Sep 20, 2024
95eac52
ConversationHandler::HasAnyHistory ignores staged entries
petemill Sep 20, 2024
c8588e9
AIChatService: erase from content_conversation map, and test it
petemill Sep 20, 2024
e1bac5c
MaybeUnlink should check if client is connected
petemill Sep 20, 2024
23b728d
fix android again?
petemill Sep 20, 2024
5f89085
ChromeAutocompleteProviderClient should check AIChatService exists
petemill Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ private static void openURL(String url) {
}

@CalledByNative
private static void handleVoiceRecognition(
WebContents chatWindowWebContents, WebContents contextWebContents) {
private static void handleVoiceRecognition(WebContents webContents, String conversation_uuid) {
new BraveLeoVoiceRecognitionHandler(
chatWindowWebContents.getTopLevelNativeWindow(), contextWebContents)
webContents.getTopLevelNativeWindow(), webContents, conversation_uuid)
.startVoiceRecognition();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,12 @@ public static void verifySubscription(Callback callback) {
}

public static void openLeoQuery(
WebContents webContents, String query, boolean openLeoChatWindow) {
WebContents webContents,
String conversationUuid,
String query,
boolean openLeoChatWindow) {
try {
BraveLeoUtilsJni.get().openLeoQuery(webContents, query);
BraveLeoUtilsJni.get().openLeoQuery(webContents, conversationUuid, query);
if (openLeoChatWindow) {
BraveActivity activity = BraveActivity.getBraveActivity();
activity.openBraveLeo();
Expand Down Expand Up @@ -106,6 +109,6 @@ public static void bringMainActivityOnTop() {

@NativeMethods
public interface Natives {
void openLeoQuery(WebContents webContents, String query);
void openLeoQuery(WebContents webContents, String conversationUuid, String query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class BraveLeoVoiceRecognitionHandler {
private static final String TAG = "LeoVoiceRecognition";
private WindowAndroid mWindowAndroid;
private WebContents mContextWebContents;
private String mConversationUuid;

/** Callback for when we receive voice search results after initiating voice recognition. */
class VoiceRecognitionCompleteCallback implements WindowAndroid.IntentCallback {
Expand Down Expand Up @@ -66,7 +67,8 @@ private void handleTranscriptionResult(Intent data) {
if (TextUtils.isEmpty(topResultQuery)) {
return;
}
BraveLeoUtils.openLeoQuery(mContextWebContents, topResultQuery, false);
BraveLeoUtils.openLeoQuery(
mContextWebContents, mConversationUuid, topResultQuery, false);
}
}

Expand Down Expand Up @@ -96,9 +98,10 @@ public float getConfidence() {
}

public BraveLeoVoiceRecognitionHandler(
WindowAndroid windowAndroid, WebContents contextWebContents) {
WindowAndroid windowAndroid, WebContents contextWebContents, String conversationUuid) {
mWindowAndroid = windowAndroid;
mContextWebContents = contextWebContents;
mConversationUuid = conversationUuid;
}

private List<BraveLeoVoiceRecognitionHandler.VoiceResult> convertBundleToVoiceResults(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ public boolean isLeoEnabled() {
}

@Override
public void openLeoQuery(WebContents webContents, String query) {
public void openLeoQuery(WebContents webContents, String conversationUuid, String query) {
mDelegate.clearOmniboxFocus();
BraveLeoUtils.openLeoQuery(webContents, query, true);
BraveLeoUtils.openLeoQuery(webContents, conversationUuid, query, true);
}

@Override
Expand All @@ -166,7 +166,7 @@ void onVoiceResults(@Nullable List<VoiceRecognitionHandler.VoiceResult> voiceRes
// Remove the start word from the query and process it.
topResultQuery =
topResultQuery.substring(LEO_START_WORD_UPPER_CASE.length()).trim();
openLeoQuery(tab.getWebContents(), topResultQuery);
openLeoQuery(tab.getWebContents(), "", topResultQuery);

// Clear the voice results to prevent the query from being processed by Chromium
// since it's already handled by Leo.
Expand Down
1 change: 1 addition & 0 deletions browser/DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include_rules = [
"+media/webrtc", # For webrtc media switches.
"+mojo/public",
"+net",
"+printing/buildflags/buildflags.h",
"+sandbox/mac",
"+sandbox/policy",
"+services/audio/public",
Expand Down
18 changes: 10 additions & 8 deletions browser/ai_chat/ai_chat_browsertests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,21 @@ class AiChatBrowserTest : public InProcessBrowserTest {
std::string FetchPageContent() {
std::string content;
base::RunLoop run_loop;
ai_chat::FetchPageContent(
browser()->tab_strip_model()->GetActiveWebContents(), "",
base::BindLambdaForTesting(
[&run_loop, &content](std::string page_content, bool is_video,
std::string invalidation_token) {
content = std::move(page_content);
run_loop.Quit();
}));
page_content_fetcher_ = std::make_unique<PageContentFetcher>(
browser()->tab_strip_model()->GetActiveWebContents());
page_content_fetcher_->FetchPageContent(
"", base::BindLambdaForTesting(
[&run_loop, &content](std::string page_content, bool is_video,
std::string invalidation_token) {
content = std::move(page_content);
run_loop.Quit();
}));
run_loop.Run();
return content;
}

private:
std::unique_ptr<PageContentFetcher> page_content_fetcher_;
content::ContentMockCertVerifier mock_cert_verifier_;
net::EmbeddedTestServer https_server_{net::EmbeddedTestServer::TYPE_HTTPS};
};
Expand Down
2 changes: 1 addition & 1 deletion browser/ai_chat/ai_chat_metrics_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ IN_PROC_BROWSER_TEST_F(AIChatMetricsTest, ContextMenuActions) {
ai_chat_metrics_->RecordEnabled(
true, true,
base::BindLambdaForTesting(
[&](mojom::PageHandler::GetPremiumStatusCallback callback) {
[&](mojom::Service::GetPremiumStatusCallback callback) {
std::move(callback).Run(mojom::PremiumStatus::Active, nullptr);
}));
histogram_tester_.ExpectUniqueSample(kMostUsedContextMenuActionHistogramName,
Expand Down
125 changes: 62 additions & 63 deletions browser/ai_chat/ai_chat_render_view_context_menu_browsertest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@

#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "brave/app/brave_command_ids.h"
#include "brave/browser/ai_chat/ai_chat_service_factory.h"
#include "brave/browser/ui/brave_browser.h"
#include "brave/browser/ui/sidebar/sidebar_controller.h"
#include "brave/browser/ui/sidebar/sidebar_model.h"
#include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h"
#include "brave/components/ai_chat/core/browser/ai_chat_service.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_engine_consumer.h"
#include "brave/components/ai_chat/core/browser/engine/mock_remote_completion_client.h"
#include "brave/components/ai_chat/core/browser/utils.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/constants/brave_paths.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu_browsertest_util.h"
#include "chrome/browser/renderer_context_menu/render_view_context_menu_test_util.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
Expand All @@ -44,40 +47,21 @@ using ::testing::_;

namespace ai_chat {

class MockEngineConsumer : public EngineConsumer {
public:
MOCK_METHOD(void,
GenerateQuestionSuggestions,
(const bool&, const std::string&, SuggestedQuestionsCallback),
(override));
MOCK_METHOD(void,
GenerateAssistantResponse,
(const bool&,
const std::string&,
const ConversationHistory&,
const std::string&,
GenerationDataCallback,
GenerationCompletedCallback),
(override));
MOCK_METHOD(void,
GenerateRewriteSuggestion,
(std::string,
const std::string&,
GenerationDataCallback,
GenerationCompletedCallback),
(override));
MOCK_METHOD(void, SanitizeInput, (std::string&), (override));
MOCK_METHOD(void, ClearAllQueries, (), (override));
MOCK_METHOD(void,
UpdateModelOptions,
(const mojom::ModelOptions&),
(override));
};
namespace {

void ExecuteRewriteCommand(RenderViewContextMenu* context_menu) {
// Calls EngineConsumer::GenerateRewriteSuggestion
context_menu->ExecuteCommand(IDC_AI_CHAT_CONTEXT_SHORTEN, 0);
context_menu->Cancel();
}

} // namespace

class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
public:
AIChatRenderViewContextMenuBrowserTest()
: https_server_(net::EmbeddedTestServer::TYPE_HTTPS) {}
: ai_engine_(std::make_unique<MockEngineConsumer>()),
https_server_(net::EmbeddedTestServer::TYPE_HTTPS) {}

~AIChatRenderViewContextMenuBrowserTest() override = default;

Expand Down Expand Up @@ -114,28 +98,13 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {

void TestRewriteInPlace(
content::WebContents* web_contents,
MockEngineConsumer* mock_engine,
const std::string& element_id,
const std::string& expected_selected_text,
const std::vector<std::string>& received_data,
base::expected<std::string, mojom::APIError> completed_result,
const std::string& expected_updated_text) {
base::RunLoop run_loop;
// Verify that rewrite is requested
EXPECT_CALL(*mock_engine, GenerateRewriteSuggestion(_, _, _, _))
.WillOnce([&](std::string text, const std::string& question,
EngineConsumer::GenerationDataCallback data_callback,
EngineConsumer::GenerationCompletedCallback callback) {
ASSERT_TRUE(callback);
ASSERT_TRUE(data_callback);
for (const auto& data : received_data) {
auto event = mojom::ConversationEntryEvent::NewCompletionEvent(
mojom::CompletionEvent::New(data));
data_callback.Run(std::move(event));
}
std::move(callback).Run(completed_result);
run_loop.Quit();
});
MockEngineConsumer* ai_engine;

// Select text in the element and create context menu to execute a rewrite
// command.
Expand All @@ -154,15 +123,42 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
base::StringPrintf("getRectY('%s')", element_id.c_str()))
.ExtractInt();

// Calls ConversationDriver::SubmitSelectedText
ContextMenuNotificationObserver context_menu_observer(
IDC_AI_CHAT_CONTEXT_SHORTEN);
RenderViewContextMenu::RegisterMenuShownCallbackForTesting(
base::BindLambdaForTesting([&](RenderViewContextMenu* context_menu) {
auto* brave_context_menu =
static_cast<BraveRenderViewContextMenu*>(context_menu);
brave_context_menu->SetAIEngineForTesting(
std::make_unique<MockEngineConsumer>());
ai_engine = static_cast<MockEngineConsumer*>(
brave_context_menu->GetAIEngineForTesting());
// Verify that rewrite is requested
EXPECT_CALL(*ai_engine, GenerateRewriteSuggestion(_, _, _, _))
.WillOnce(
[&](std::string text, const std::string& question,
EngineConsumer::GenerationDataCallback data_callback,
EngineConsumer::GenerationCompletedCallback callback) {
ASSERT_TRUE(callback);
ASSERT_TRUE(data_callback);
for (const auto& data : received_data) {
auto event =
mojom::ConversationEntryEvent::NewCompletionEvent(
mojom::CompletionEvent::New(data));
data_callback.Run(std::move(event));
}
std::move(callback).Run(completed_result);
run_loop.Quit();
});
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&ExecuteRewriteCommand, context_menu));
darkdh marked this conversation as resolved.
Show resolved Hide resolved
}));

web_contents->GetPrimaryMainFrame()
->GetRenderViewHost()
->GetWidget()
->ShowContextMenuAtPoint(gfx::Point(x, y), ui::MENU_SOURCE_MOUSE);
run_loop.Run();
testing::Mock::VerifyAndClearExpectations(mock_engine);
EXPECT_NE(ai_engine, nullptr);
testing::Mock::VerifyAndClearExpectations(ai_engine);

// Verify that the text is rewritten as expected.
std::string updated_text =
Expand All @@ -187,6 +183,7 @@ class AIChatRenderViewContextMenuBrowserTest : public InProcessBrowserTest {
}

private:
std::unique_ptr<MockEngineConsumer> ai_engine_;
content::ContentMockCertVerifier mock_cert_verifier_;
net::test_server::EmbeddedTestServer https_server_;
};
Expand All @@ -205,38 +202,40 @@ IN_PROC_BROWSER_TEST_F(AIChatRenderViewContextMenuBrowserTest, RewriteInPlace) {
ai_chat::AIChatTabHelper::FromWebContents(contents);
ASSERT_TRUE(helper);

helper->SetEngineForTesting(std::make_unique<MockEngineConsumer>());
auto* mock_engine =
static_cast<MockEngineConsumer*>(helper->GetEngineForTesting());
ConversationHandler* conversation_handler =
ai_chat::AIChatServiceFactory::GetInstance()
->GetForBrowserContext(browser()->profile())
->GetOrCreateConversationHandlerForContent(helper->GetContentId(),
helper->GetWeakPtr());
ASSERT_TRUE(conversation_handler);

// Test rewriting textarea value and verify that the response tag is ignored
// by BraveRenderViewContextMenu
TestRewriteInPlace(contents, mock_engine, "textarea", "I'm textarea.",
TestRewriteInPlace(contents, "textarea", "I'm textarea.",
{"O", "OK", "<", "</", "</r", "</re", "</response"}, "",
"OK");

// Do the same again to make sure it still works at the second time.
TestRewriteInPlace(contents, mock_engine, "textarea", "OK",
{"O", "OK", "OK2"}, "", "OK2");
TestRewriteInPlace(contents, "textarea", "OK", {"O", "OK", "OK2"}, "", "OK2");

// Select text in text input and create context menu to execute a rewrite cmd.
// Verify that the text is rewritten.
TestRewriteInPlace(contents, mock_engine, "input_text", "I'm input.",
{"O", "OK", "OK3"}, "", "OK3");
TestRewriteInPlace(contents, mock_engine, "contenteditable",
"I'm contenteditable.", {"O", "OK", "OK4"}, "", "OK4");
TestRewriteInPlace(contents, "input_text", "I'm input.", {"O", "OK", "OK3"},
"", "OK3");
TestRewriteInPlace(contents, "contenteditable", "I'm contenteditable.",
{"O", "OK", "OK4"}, "", "OK4");

// Error case handling tests and verify that the text is not rewritten.
// 1) Get error in completed callback immediately.
EXPECT_FALSE(IsAIChatSidebarActive());
TestRewriteInPlace(contents, mock_engine, "textarea", "OK2", {},
TestRewriteInPlace(contents, "textarea", "OK2", {},
base::unexpected(mojom::APIError::ConnectionIssue), "OK2");
EXPECT_TRUE(IsAIChatSidebarActive());
GetSidebarController()->DeactivateCurrentPanel();

EXPECT_FALSE(IsAIChatSidebarActive());
// 2) Get partial streaming responses then error in completed callback.
TestRewriteInPlace(contents, mock_engine, "textarea", "OK2", {"N", "O"},
TestRewriteInPlace(contents, "textarea", "OK2", {"N", "O"},
base::unexpected(mojom::APIError::ConnectionIssue), "OK2");
EXPECT_TRUE(IsAIChatSidebarActive());
}
Expand Down
Loading
Loading