Skip to content

Commit

Permalink
Fix crash in AI Chat when conversation is unassociated from page cont…
Browse files Browse the repository at this point in the history
…ent (#26495)

* AIChat ConversationHandler remove NOTREACHED_IN_MIGRATION

* ConversationHandler notifies AssociatedContentDelegate when conversation is not associated
  • Loading branch information
petemill committed Nov 12, 2024
1 parent e0a230f commit 603ded3
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void AssociatedContentDriver::AddRelatedConversation(
associated_conversations_.insert(conversation);
}

void AssociatedContentDriver::OnRelatedConversationDestroyed(
void AssociatedContentDriver::OnRelatedConversationDisassociated(
ConversationHandler* conversation) {
associated_conversations_.erase(conversation);
}
Expand Down
3 changes: 2 additions & 1 deletion components/ai_chat/core/browser/associated_content_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <set>
#include <string>
#include <string_view>
#include <vector>

#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
Expand Down Expand Up @@ -48,7 +49,7 @@ class AssociatedContentDriver

// ConversationHandler::AssociatedContentDelegate
void AddRelatedConversation(ConversationHandler* conversation) override;
void OnRelatedConversationDestroyed(
void OnRelatedConversationDisassociated(
ConversationHandler* conversation) override;
int GetContentId() const override;
GURL GetURL() const override;
Expand Down
33 changes: 20 additions & 13 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,7 @@ ConversationHandler::ConversationHandler(
}

ConversationHandler::~ConversationHandler() {
if (associated_content_delegate_) {
associated_content_delegate_->OnRelatedConversationDestroyed(this);
}
DisassociateContentDelegate();
}

void ConversationHandler::AddObserver(Observer* observer) {
Expand Down Expand Up @@ -207,8 +205,10 @@ bool ConversationHandler::HasAnyHistory() {
}

void ConversationHandler::InitEngine() {
CHECK(!model_key_.empty());
const mojom::Model* model = model_service_->GetModel(model_key_);
const mojom::Model* model = nullptr;
if (!model_key_.empty()) {
model = model_service_->GetModel(model_key_);
}
// Make sure we get a valid model, defaulting to static default or first.
if (!model) {
// It is unexpected that we get here. Dump a call stack
Expand All @@ -217,8 +217,10 @@ void ConversationHandler::InitEngine() {
base::debug::DumpWithoutCrashing();
// Use default
model = model_service_->GetModel(features::kAIModelsDefaultKey.Get());
DCHECK(model) << "The default model set via feature param does not exist";
if (!model) {
SCOPED_CRASH_KEY_STRING1024("BraveAIChatModel", "key",
features::kAIModelsDefaultKey.Get());
base::debug::DumpWithoutCrashing();
const auto& all_models = model_service_->GetModels();
// Use first if given bad default value
model = all_models.at(0).get();
Expand Down Expand Up @@ -289,7 +291,6 @@ void ConversationHandler::SetAssociatedContentDelegate(

// Unarchive content
if (archive_content_) {
associated_content_delegate_ = nullptr;
archive_content_ = nullptr;
} else if (!chat_history_.empty()) {
// Cannot associate new content with a conversation which already has
Expand All @@ -300,6 +301,7 @@ void ConversationHandler::SetAssociatedContentDelegate(
return;
}

DisassociateContentDelegate();
associated_content_delegate_ = delegate;
associated_content_delegate_->AddRelatedConversation(this);
// Default to send page contents when we have a valid contents.
Expand Down Expand Up @@ -444,12 +446,10 @@ void ConversationHandler::ChangeModel(const std::string& model_key) {
CHECK(!model_key.empty());
// Check that the key exists
auto* new_model = model_service_->GetModel(model_key);
if (!new_model) {
NOTREACHED_IN_MIGRATION()
<< "No matching model found for key: " << model_key;
return;
if (new_model) {
model_key_ = new_model->key;
}
model_key_ = new_model->key;
// Always call InitEngine, even with a bad key as we need a model
InitEngine();
}

Expand Down Expand Up @@ -553,7 +553,7 @@ void ConversationHandler::SubmitHumanConversationEntry(
// Now the conversation is committed, we can remove some unneccessary data
// if we're not associated with a page.
suggestions_.clear();
associated_content_delegate_ = nullptr;
DisassociateContentDelegate();
OnSuggestedQuestionsChanged();
// Perform generation immediately
PerformAssistantGeneration(question_part);
Expand Down Expand Up @@ -738,6 +738,13 @@ void ConversationHandler::PerformQuestionGeneration(
weak_ptr_factory_.GetWeakPtr()));
}

void ConversationHandler::DisassociateContentDelegate() {
if (associated_content_delegate_) {
associated_content_delegate_->OnRelatedConversationDisassociated(this);
associated_content_delegate_ = nullptr;
}
}

void ConversationHandler::GetAssociatedContentInfo(
GetAssociatedContentInfoCallback callback) {
BuildAssociatedContentInfo();
Expand Down
8 changes: 7 additions & 1 deletion components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ConversationHandler : public mojom::ConversationHandler,
AssociatedContentDelegate();
virtual ~AssociatedContentDelegate();
virtual void AddRelatedConversation(ConversationHandler* conversation) {}
virtual void OnRelatedConversationDestroyed(
virtual void OnRelatedConversationDisassociated(
ConversationHandler* conversation) {}
// Unique ID for the content. For browser Tab content, this should be
// a navigation ID that's re-used during back navigations.
Expand Down Expand Up @@ -297,6 +297,12 @@ class ConversationHandler : public mojom::ConversationHandler,
bool is_video,
std::string invalidation_token);

// Disassociate with the current associated content. Use this instead of
// settings associated_content_delegegate_ to nullptr to ensure that we
// inform the delegate, otherwise when this class instance is destroyed,
// the delegate will not be informed.
void DisassociateContentDelegate();

void OnGetStagedEntriesFromContent(
const std::optional<std::vector<SearchQuerySummary>>& entries);

Expand Down
22 changes: 22 additions & 0 deletions components/ai_chat/core/browser/conversation_handler_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ class MockAssociatedContent
(ConversationHandler::GetStagedEntriesCallback),
(override));

MOCK_METHOD(void,
OnRelatedConversationDisassociated,
(ConversationHandler*),
(override));

base::WeakPtr<ConversationHandler::AssociatedContentDelegate> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
Expand Down Expand Up @@ -409,11 +414,19 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) {
EXPECT_CALL(*engine, SanitizeInput(StrEq(selected_text)));
EXPECT_CALL(*engine, SanitizeInput(StrEq(expected_turn_text)));

// Submitting conversation entry should inform associated content
// that it is no longer associated with the conversation
// and shouldn't access the conversation because the conversation
// will not be considering the associated content for lifetime notifications.
EXPECT_CALL(*associated_content_, OnRelatedConversationDisassociated)
.Times(1);

conversation_handler_->SubmitSelectedText(
"I have spoken.", mojom::ActionType::SUMMARIZE_SELECTED_TEXT);

task_environment_.RunUntilIdle();
testing::Mock::VerifyAndClearExpectations(&client);
testing::Mock::VerifyAndClearExpectations(associated_content_.get());
// article_text_ and suggestions_ should be cleared when page content is
// unlinked.
conversation_handler_->GetAssociatedContentInfo(base::BindLambdaForTesting(
Expand Down Expand Up @@ -1360,6 +1373,15 @@ TEST_F(ConversationHandlerUnitTest, SelectedLanguage) {
testing::Mock::VerifyAndClearExpectations(engine);
}

TEST_F(ConversationHandlerUnitTest, Destuctor) {
// Verify that the conversation handler cleans up the associated content
// object when it is destroyed.
EXPECT_CALL(*associated_content_, OnRelatedConversationDisassociated)
.Times(1);
conversation_handler_.reset();
testing::Mock::VerifyAndClearExpectations(associated_content_.get());
}

class PageContentRefineTest : public ConversationHandlerUnitTest,
public testing::WithParamInterface<bool> {
public:
Expand Down

0 comments on commit 603ded3

Please sign in to comment.