From 3fcb6c24762ef084be10015e869a837846ce4edd Mon Sep 17 00:00:00 2001
From: Mini256
Date: Sun, 2 Jun 2024 18:54:07 +0800
Subject: [PATCH] feat: metadata-filter support custom llm (#148)

* feat: metadata-filter support custom llm

* fix

---
 src/app/api/test/reranker/route.ts            |  2 +-
 .../api/v1/indexes/[name]/retrieve/route.ts   |  2 +-
 src/core/services/llamaindex/chating.ts       |  2 +-
 src/core/services/llamaindex/indexing.ts      |  2 +-
 src/lib/llamaindex/builders/indices.ts        |  2 +-
 src/lib/llamaindex/builders/llm.ts            |  2 +-
 .../llamaindex/builders/metadata-filter.ts    | 15 ++++---
 src/lib/llamaindex/config/metadata-filter.ts  |  2 +
 .../postfilters/MetadataPostFilter.ts         | 40 +++++++++++--------
 9 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/src/app/api/test/reranker/route.ts b/src/app/api/test/reranker/route.ts
index f05cc674..054a3f51 100644
--- a/src/app/api/test/reranker/route.ts
+++ b/src/app/api/test/reranker/route.ts
@@ -76,7 +76,7 @@ export const POST = defineHandler({
   body
 }) => {
   const { query, config, llmConfig, top_k } = body;
-  const llm = llmConfig ? await buildLLM(llmConfig) : undefined;
+  const llm = llmConfig ? buildLLM(llmConfig) : undefined;
   const serviceContext = serviceContextFromDefaults({
     llm: llm
   })
diff --git a/src/app/api/v1/indexes/[name]/retrieve/route.ts b/src/app/api/v1/indexes/[name]/retrieve/route.ts
index 8b771bd8..d4eabc4c 100644
--- a/src/app/api/v1/indexes/[name]/retrieve/route.ts
+++ b/src/app/api/v1/indexes/[name]/retrieve/route.ts
@@ -36,7 +36,7 @@ export const POST = defineHandler({
   const flow = await getFlow(baseRegistry);
 
   const serviceContext = serviceContextFromDefaults({
-    llm: await buildLLM(llmConfig),
+    llm: buildLLM(llmConfig),
     embedModel: await buildEmbedding(index.config.embedding),
   });
 
diff --git a/src/core/services/llamaindex/chating.ts b/src/core/services/llamaindex/chating.ts
index 9d4cb6b6..621e5ac9 100644
--- a/src/core/services/llamaindex/chating.ts
+++ b/src/core/services/llamaindex/chating.ts
@@ -135,7 +135,7 @@
     });
 
     // Service context.
-    const llm = await buildLLM(llmConfig!, trace);
+    const llm = buildLLM(llmConfig!, trace);
     const promptHelper = new PromptHelper(llm.metadata.contextWindow);
     const embedModel = await buildEmbedding(this.index.config.embedding);
     const serviceContext = serviceContextFromDefaults({
diff --git a/src/core/services/llamaindex/indexing.ts b/src/core/services/llamaindex/indexing.ts
index 2ebb1037..481172eb 100644
--- a/src/core/services/llamaindex/indexing.ts
+++ b/src/core/services/llamaindex/indexing.ts
@@ -47,7 +47,7 @@ export class LlamaindexIndexProvider extends DocumentIndexProvider {
     });
 
     // Select and config the llm for indexing (metadata extractor).
-    const llm = await buildLLM(index.config.llm);
+    const llm = buildLLM(index.config.llm);
     llm.metadata.model = index.config.llm.options?.model!;
 
     // Select and config the embedding (important and immutable)
diff --git a/src/lib/llamaindex/builders/indices.ts b/src/lib/llamaindex/builders/indices.ts
index 87de1710..d1ec2f22 100644
--- a/src/lib/llamaindex/builders/indices.ts
+++ b/src/lib/llamaindex/builders/indices.ts
@@ -24,7 +24,7 @@ export async function createVectorStoreIndex (id: number) {
       dimensions: DEFAULT_TIDB_VECTOR_DIMENSIONS,
     }),
     serviceContext: serviceContextFromDefaults({
-      llm: await buildLLM(index.config.llm),
+      llm: buildLLM(index.config.llm),
       embedModel: await buildEmbedding(index.config.embedding),
     }),
   });
diff --git a/src/lib/llamaindex/builders/llm.ts b/src/lib/llamaindex/builders/llm.ts
index 3c7ac4fb..d88cfeb8 100644
--- a/src/lib/llamaindex/builders/llm.ts
+++ b/src/lib/llamaindex/builders/llm.ts
@@ -4,7 +4,7 @@ import {LangfuseTraceClient} from "langfuse";
 import {OpenAI, Ollama} from "llamaindex";
 import {Bitdeer} from "@/lib/llamaindex/llm/bitdeer";
 
-export async function buildLLM ({ provider, options}: LLMConfig, trace?: LangfuseTraceClient) {
+export function buildLLM ({ provider, options}: LLMConfig, trace?: LangfuseTraceClient) {
   let baseLLM;
   switch (provider) {
     case LLMProvider.OPENAI:
diff --git a/src/lib/llamaindex/builders/metadata-filter.ts b/src/lib/llamaindex/builders/metadata-filter.ts
index 63e3e6a8..f766e13e 100644
--- a/src/lib/llamaindex/builders/metadata-filter.ts
+++ b/src/lib/llamaindex/builders/metadata-filter.ts
@@ -1,17 +1,22 @@
+import {buildLLM} from "@/lib/llamaindex/builders/llm";
 import {MetadataFilterConfig} from "@/lib/llamaindex/config/metadata-filter";
 import { MetadataPostFilter } from "@/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter";
 import {ServiceContext} from 'llamaindex';
 
-export function buildMetadataFilter (serviceContext: ServiceContext, { provider, options }: MetadataFilterConfig) {
-  switch (provider) {
+export function buildMetadataFilter (serviceContext: ServiceContext, config: MetadataFilterConfig) {
+  switch (config.provider) {
     case 'default':
+      let llm = serviceContext.llm;
+      if (config.options?.llm) {
+        llm = buildLLM(config.options.llm);
+      }
       return new MetadataPostFilter({
-        ...options,
-        serviceContext,
+        llm,
+      });
     default:
-      throw new Error(`Unknown metadata filter provider: ${provider}`)
+      throw new Error(`Unknown metadata filter provider: ${config.provider}`)
   }
 }
diff --git a/src/lib/llamaindex/config/metadata-filter.ts b/src/lib/llamaindex/config/metadata-filter.ts
index 3312270b..743f38ab 100644
--- a/src/lib/llamaindex/config/metadata-filter.ts
+++ b/src/lib/llamaindex/config/metadata-filter.ts
@@ -1,3 +1,4 @@
+import {LLMConfigSchema} from "@/lib/llamaindex/config/llm";
 import {z} from "zod";
 
 export const metadataFilterSchema = z.object({
@@ -25,6 +26,7 @@ export enum MetadataFilterProvider {
 }
 
 export const DefaultMetadataFilterOptions = z.object({
+  llm: LLMConfigSchema.optional(),
   metadata_fields: z.array(metadataFieldSchema).optional(),
   filters: z.array(metadataFilterSchema).optional()
 });
diff --git a/src/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter.ts b/src/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter.ts
index 23c296b4..a938e6ca 100644
--- a/src/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter.ts
+++ b/src/lib/llamaindex/postprocessors/postfilters/MetadataPostFilter.ts
@@ -1,5 +1,6 @@
 import {MetadataField, MetadataFieldFilter} from "@/lib/llamaindex/config/metadata-filter";
"@/lib/llamaindex/config/metadata-filter"; import {BaseNodePostprocessor, NodeWithScore, ServiceContext, serviceContextFromDefaults} from "llamaindex"; +import {BaseLLM} from "llamaindex/llm/base"; import {DateTime} from "luxon"; export const defaultMetadataFilterChoicePrompt = ({metadataFields, query}: { @@ -48,41 +49,49 @@ export type MetadataFilterChoicePrompt = typeof defaultMetadataFilterChoicePromp export type MetadataPostFilterOptions = Partial; export class MetadataPostFilter implements BaseNodePostprocessor { - serviceContext: ServiceContext = serviceContextFromDefaults(); - metadataFilterChoicePrompt: MetadataFilterChoicePrompt = defaultMetadataFilterChoicePrompt; - + serviceContext: ServiceContext; + /** + * The llm model used for filters generating. + */ + llm: BaseLLM; + /** + * The prompt for metadata filter choice. + */ + metadataFilterChoicePrompt: MetadataFilterChoicePrompt; /** * The definition of metadata fields. */ - metadata_fields: MetadataField[] = []; + metadata_fields: MetadataField[]; /** * Provide the filters to apply to the search. */ - filters: MetadataFieldFilter[] | null = null; + filters: MetadataFieldFilter[] | null; constructor(init?: MetadataPostFilterOptions) { - Object.assign(this, init); + this.serviceContext = init?.serviceContext ?? serviceContextFromDefaults(); + this.llm = init?.llm ?? this.serviceContext.llm; + this.metadata_fields = init?.metadata_fields ?? []; + this.filters = init?.filters ?? null; + this.metadataFilterChoicePrompt = init?.metadataFilterChoicePrompt || defaultMetadataFilterChoicePrompt; } async postprocessNodes(nodes: NodeWithScore[], query: string): Promise { let filters; if (this.filters) { filters = this.filters; - console.info('Apply provided filters:', filters); + console.info('[Metadata Filter] Provided filters: ', filters); } else { const start = DateTime.now(); filters = await this.generateFilters(query); - const end = DateTime.now(); - console.info('Generate filters took:', end.diff(start).as('seconds'), 's'); - console.info('Apply generated filters:', filters); + const duration = DateTime.now().diff(start).as('milliseconds') + console.info(`[Metadata Filter] Generate filters (took: ${duration} ms): `, filters); } - console.log('Nodes before filter:', nodes.length, 'nodes'); - let filteredNodes = await this.filterNodes(nodes, filters); - console.log('Nodes after filter:', filteredNodes.length, 'nodes'); + const filteredNodes = await this.filterNodes(nodes, filters); + console.log(`[Metadata Filter] Applied provided/generated filter (before: ${nodes.length} nodes, after: ${filteredNodes.length} nodes).`); if (filteredNodes.length === 0) { - console.warn('No nodes left after filtering, fallback to using all nodes.'); + console.warn('[Metadata Filter] No nodes left after filtering, fallback to using all nodes.'); return nodes; } @@ -91,12 +100,11 @@ export class MetadataPostFilter implements BaseNodePostprocessor { async generateFilters(query: string): Promise { try { - const llm = this.serviceContext.llm; const prompt = this.metadataFilterChoicePrompt({ metadataFields: this.metadata_fields, query }); - const raw = await llm.chat({ + const raw = await this.llm.chat({ messages: [ { role: 'system',