Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): support config reranker #211

Merged
merged 3 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,5 @@ TIDB_USER=
TIDB_PASSWORD=
TIDB_DATABASE=

# Replace with your own Jina AI API key
# You can get one from https://jina.ai/reranker/
JINAAI_API_KEY=

# *** DO NOT CHANGE BELOW CONFIGURATIONS UNLESS YOU KNOW WHAT YOU ARE DOING
DSP_CACHEBOOL=false
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ A conversational search tool based on GraphRAG (Knowledge Graph) that built on t
>
> 1. Set up a [TiDB Serverless cluster](https://docs.pingcap.com/tidbcloud/tidb-cloud-quickstart).
> 2. Install [Docker Compose](https://docs.docker.com/compose/install/).
> 3. Jina AI API key, get one from [Jina AI](https://jina.ai/reranker/).

1. Clone the repository:

Expand All @@ -52,7 +51,6 @@ A conversational search tool based on GraphRAG (Knowledge Graph) that built on t

Replace the following placeholders with your own values:
- `SECRET_KEY`: you can generate a random secret key using `python3 -c "import secrets; print(secrets.token_urlsafe(32))"`
- `JINAAI_API_KEY`: get one from [Jina AI](https://jina.ai/reranker/)
- `TIDB_HOST`, `TIDB_USER`, `TIDB_PASSWORD` and `TIDB_DATABASE`: get them from your [TiDB Serverless cluster](https://tidbcloud.com/)

- Note: TiDB Serverless provides a default database named `test`. If you want to use a different database name, you need to create a new database in the TiDB Serverless console.
Expand Down
20 changes: 12 additions & 8 deletions backend/app/alembic/versions/bd17a4ebccc5_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,34 @@
Create Date: 2024-08-08 01:20:42.069228

"""

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType


# revision identifiers, used by Alembic.
revision = 'bd17a4ebccc5'
down_revision = 'a8c79553c9f6'
revision = "bd17a4ebccc5"
down_revision = "a8c79553c9f6"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('data_sources', sa.Column('deleted_at', sa.DateTime(), nullable=True))
op.drop_index('source_uri', table_name='documents')
op.add_column('relationships', sa.Column('chunk_id', sqlmodel.sql.sqltypes.GUID(), nullable=True))
op.add_column("data_sources", sa.Column("deleted_at", sa.DateTime(), nullable=True))
op.drop_index("source_uri", table_name="documents")
op.add_column(
"relationships",
sa.Column("chunk_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('relationships', 'chunk_id')
op.create_index('source_uri', 'documents', ['source_uri'], unique=True)
op.drop_column('data_sources', 'deleted_at')
op.drop_column("relationships", "chunk_id")
op.create_index("source_uri", "documents", ["source_uri"], unique=True)
op.drop_column("data_sources", "deleted_at")
# ### end Alembic commands ###
66 changes: 66 additions & 0 deletions backend/app/alembic/versions/e32f1e546eec_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""empty message

Revision ID: e32f1e546eec
Revises: bd17a4ebccc5
Create Date: 2024-08-08 03:55:14.042290

"""

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from app.models.base import AESEncryptedColumn


# revision identifiers, used by Alembic.
revision = "e32f1e546eec"
down_revision = "bd17a4ebccc5"
branch_labels = None
depends_on = None


def upgrade():
    """Create the ``reranker_models`` table and link ``chat_engines`` to it.

    Adds a nullable ``chat_engines.reranker_id`` column together with an
    unnamed foreign key pointing at ``reranker_models.id``.
    """

    def _timestamp_column(column_name):
        # Both audit columns share the same definition apart from their name.
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=True,
        )

    # ### commands auto generated by Alembic - please adjust! ###
    reranker_columns = [
        _timestamp_column("created_at"),
        _timestamp_column("updated_at"),
        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
        sa.Column(
            "provider",
            sa.Enum("JINA", "COHERE", name="rerankerprovider"),
            nullable=False,
        ),
        sa.Column(
            "model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
        ),
        sa.Column("top_n", sa.Integer(), nullable=False),
        sa.Column("config", sa.JSON(), nullable=True),
        sa.Column("is_default", sa.Boolean(), nullable=False),
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("credentials", AESEncryptedColumn(), nullable=True),
    ]
    op.create_table(
        "reranker_models",
        *reranker_columns,
        sa.PrimaryKeyConstraint("id"),
    )

    reranker_fk_column = sa.Column("reranker_id", sa.Integer(), nullable=True)
    op.add_column("chat_engines", reranker_fk_column)
    # The constraint name is left as None, so the database chooses it.
    op.create_foreign_key(
        None, "chat_engines", "reranker_models", ["reranker_id"], ["id"]
    )
    # ### end Alembic commands ###


def downgrade():
    """Revert this revision: unlink ``chat_engines`` and drop ``reranker_models``.

    The upgrade creates the ``chat_engines.reranker_id`` foreign key without
    an explicit name, so the name is chosen by the database. It is looked up
    by reflection here and dropped first; MySQL-compatible databases refuse to
    drop a column (or a referenced table) while a foreign key still uses it.

    NOTE(review): reflection needs a live connection, so this downgrade cannot
    be rendered in Alembic's offline ``--sql`` mode.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    inspector = sa.inspect(op.get_bind())
    for foreign_key in inspector.get_foreign_keys("chat_engines"):
        constraint_name = foreign_key.get("name")
        if constraint_name and "reranker_id" in foreign_key["constrained_columns"]:
            op.drop_constraint(constraint_name, "chat_engines", type_="foreignkey")
    op.drop_column("chat_engines", "reranker_id")
    op.drop_table("reranker_models")
    # ### end Alembic commands ###
121 changes: 119 additions & 2 deletions backend/app/api/admin_routes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,26 @@
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, update
from langfuse import Langfuse
from llama_index.core.schema import NodeWithScore, TextNode

from app.core.config import settings
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.llm_option import admin_llm_options, LLMOption
from app.rag.embed_model_option import admin_embed_model_options, EmbeddingModelOption
from app.rag.chat_config import get_llm, get_embedding_model
from app.rag.reranker_model_option import (
admin_reranker_model_options,
RerankerModelOption,
)
from app.rag.chat_config import get_llm, get_embedding_model, get_reranker_model
from app.models import (
ChatEngine,
LLM,
AdminLLM,
EmbeddingModel,
AdminEmbeddingModel,
RerankerModel,
AdminRerankerModel,
)
from app.site_settings import SiteSetting

router = APIRouter()

Expand Down Expand Up @@ -213,3 +219,114 @@ def test_langfuse(
success = False
error = str(e)
return LangfuseTestResult(success=success, error=error)


@router.get("/admin/reranker-models/options")
def get_reranker_model_options(
    user: CurrentSuperuserDep,
) -> List[RerankerModelOption]:
    # Superuser-only endpoint: hands back the module-level list of reranker
    # options imported from app.rag.reranker_model_option, unchanged.
    return admin_reranker_model_options


@router.post("/admin/reranker-models/test")
def test_reranker_model(
    db_reranker_model: RerankerModel,
    user: CurrentSuperuserDep,
) -> LLMTestResult:
    # Smoke-test a reranker configuration: build the reranker from the posted
    # settings and run it once over a few fixed sample passages. Any exception
    # is reported in the result instead of being raised to the caller.
    sample_passages = (
        ("TiDB is a distributed SQL database.", 0.8),
        ("TiDB is compatible with MySQL protocol.", 0.6),
        ("TiFlash is a columnar storage engine.", 0.4),
    )
    try:
        reranker = get_reranker_model(
            provider=db_reranker_model.provider,
            model=db_reranker_model.model,
            # The posted top_n is ignored here; a test run only keeps 2 nodes.
            top_n=2,
            config=db_reranker_model.config,
            credentials=db_reranker_model.credentials,
        )
        reranker.postprocess_nodes(
            nodes=[
                NodeWithScore(node=TextNode(text=passage), score=score)
                for passage, score in sample_passages
            ],
            query_str="What is TiDB?",
        )
    except Exception as e:
        return LLMTestResult(success=False, error=str(e))
    return LLMTestResult(success=True, error="")


@router.get("/admin/reranker-models")
def list_reranker_models(
    session: SessionDep,
    user: CurrentSuperuserDep,
    params: Params = Depends(),
) -> Page[AdminRerankerModel]:
    # Superuser-only, paginated listing of reranker models, newest first
    # (ordered by created_at descending).
    newest_first = select(RerankerModel).order_by(RerankerModel.created_at.desc())
    return paginate(session, newest_first, params)


@router.post("/admin/reranker-models")
def create_reranker_model(
    reranker_model: RerankerModel,
    session: SessionDep,
    user: CurrentSuperuserDep,
) -> AdminRerankerModel:
    # Superuser-only: store the posted reranker model as a new row.
    session.add(reranker_model)
    session.commit()

    # Reload the instance from the database after the commit so the response
    # reflects the stored row.
    session.refresh(reranker_model)
    return reranker_model


@router.get("/admin/reranker-models/{reranker_model_id}")
def get_reranker_model_detail(
    reranker_model_id: int,
    session: SessionDep,
    user: CurrentSuperuserDep,
) -> AdminRerankerModel:
    # Superuser-only lookup of one reranker model by primary key; responds
    # with 404 when no row has that id.
    found = session.get(RerankerModel, reranker_model_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail="Reranker model not found"
    )


@router.delete("/admin/reranker-models/{reranker_model_id}")
def delete_reranker_model(
    reranker_model_id: int,
    session: SessionDep,
    user: CurrentSuperuserDep,
):
    # Superuser-only: delete one reranker model by primary key; responds with
    # 404 when no row has that id.
    #
    # The parameter name must match the `{reranker_model_id}` placeholder in
    # the route path. With a different name FastAPI treats it as a required
    # query parameter and never binds the id from the URL.
    reranker_model = session.get(RerankerModel, reranker_model_id)
    if reranker_model is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Reranker model not found"
        )
    # Detach the reranker from every chat engine that references it before the
    # row is removed, so no chat engine keeps a dangling reranker_id.
    session.exec(
        update(ChatEngine)
        .where(ChatEngine.reranker_id == reranker_model_id)
        .values(reranker_id=None)
    )
    session.delete(reranker_model)
    session.commit()
11 changes: 5 additions & 6 deletions backend/app/api/routes/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from app.api.deps import SessionDep
from app.site_settings import SiteSetting
from app.rag.chat import check_rag_required_config
from app.rag.chat import check_rag_required_config, check_rag_optional_config

router = APIRouter()

Expand All @@ -28,6 +28,7 @@ class RequiredConfigStatus(BaseModel):

class OptionalConfigStatus(BaseModel):
langfuse: bool
default_reranker: bool


class SystemConfigStatusResponse(BaseModel):
Expand All @@ -40,17 +41,15 @@ def system_bootstrap_status(session: SessionDep) -> SystemConfigStatusResponse:
has_default_llm, has_default_embedding_model, has_datasource = (
check_rag_required_config(session)
)
langfuse, default_reranker = check_rag_optional_config(session)
return SystemConfigStatusResponse(
required=RequiredConfigStatus(
default_llm=has_default_llm,
default_embedding_model=has_default_embedding_model,
datasource=has_datasource,
),
optional=OptionalConfigStatus(
langfuse=bool(
SiteSetting.langfuse_host
and SiteSetting.langfuse_secret_key
and SiteSetting.langfuse_public_key
)
langfuse=langfuse,
default_reranker=default_reranker,
),
)
1 change: 1 addition & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .data_source import DataSource, DataSourceType
from .llm import LLM, AdminLLM
from .embed_model import EmbeddingModel, AdminEmbeddingModel
from .reranker_model import RerankerModel, AdminRerankerModel
6 changes: 6 additions & 0 deletions backend/app/models/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class ChatEngine(UpdatableBaseModel, table=True):
"foreign_keys": "ChatEngine.fast_llm_id",
},
)
reranker_id: Optional[int] = Field(foreign_key="reranker_models.id", nullable=True)
reranker: "RerankerModel" = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "ChatEngine.reranker_id",
},
)
is_default: bool = Field(default=False)
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))

Expand Down
2 changes: 1 addition & 1 deletion backend/app/models/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class RelationshipBase(SQLModel):
source_entity_id: int = Field(foreign_key="entities.id")
target_entity_id: int = Field(foreign_key="entities.id")
last_modified_at: Optional[datetime] = Field(sa_column=Column(DateTime))
chunk_id: Optional[UUID] = Field(default=None)


class Relationship(RelationshipBase, table=True):
Expand All @@ -82,7 +83,6 @@ class Relationship(RelationshipBase, table=True):
"lazy": "joined",
},
)
chunk_id: UUID = Field(nullable=True)

__tablename__ = "relationships"

Expand Down
26 changes: 26 additions & 0 deletions backend/app/models/reranker_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional, Any

from sqlmodel import Field, Column, JSON

from .base import UpdatableBaseModel, AESEncryptedColumn
from app.types import RerankerProvider


class BaseRerankerModel(UpdatableBaseModel):
    # Fields shared by the stored reranker row and its admin API schema.

    # Display name, at most 64 characters.
    name: str = Field(max_length=64)
    # Which reranker provider to use (see app.types.RerankerProvider).
    provider: RerankerProvider
    # Provider-side model identifier, at most 256 characters.
    model: str = Field(max_length=256)
    # Defaults to 10 when not supplied.
    top_n: int = Field(default=10)
    # Free-form JSON settings stored in a JSON column; defaults to {}.
    config: dict | list | None = Field(default={}, sa_column=Column(JSON))
    is_default: bool = Field(default=False)


class RerankerModel(BaseRerankerModel, table=True):
    # Table-backed model for rows of "reranker_models".
    __tablename__ = "reranker_models"

    # Primary key; None until a value is assigned.
    id: Optional[int] = Field(primary_key=True, default=None)
    # Provider credentials, stored through the AESEncryptedColumn type.
    credentials: Any = Field(sa_column=Column(AESEncryptedColumn, nullable=True))


class AdminRerankerModel(BaseRerankerModel):
    # Admin API schema: the shared base fields plus a required id. It declares
    # no `credentials` field, unlike the table model.
    id: int
Loading