[RAG] Add example for RAG with Langchain.js (#550)
Add a simple example of RAG using Langchain.js, following [its
cookbook](https://js.langchain.com/v0.1/docs/expression_language/cookbook/retrieval/).

We use WebLLM for both the embedding model and the LLM, within a single
engine. There are many possible ways to achieve RAG (e.g. varying the degree
of integration with Langchain, or running the engine in a WebWorker); we only
provide a minimal example here, sketched below and shown in full in the diff.
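In essence, the pieces fit together as follows (a condensed sketch of the code added below, assuming top-level await; the model IDs are the ones used in the diff):

```ts
import * as webllm from "@mlc-ai/web-llm";

// One engine hosts both models; each request selects a model by id.
const engine = await webllm.CreateMLCEngine([
  "snowflake-arctic-embed-m-q0f32-MLC-b4", // embedding model
  "gemma-2-2b-it-q4f32_1-MLC-1k", // LLM
]);

// Embed documents and queries with the embedding model...
const emb = await engine.embeddings.create({
  input: ["mitochondria is the powerhouse of the cell"],
  model: "snowflake-arctic-embed-m-q0f32-MLC-b4",
});

// ...and generate the final answer with the LLM.
const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "<retrieval-augmented prompt>" }],
  model: "gemma-2-2b-it-q4f32_1-MLC-1k",
});
```

Langchain.js fills in the middle: a `MemoryVectorStore` backed by these embeddings retrieves the relevant context, and a `RunnableSequence` formats it into the prompt (see `simpleRAG()` in the diff).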
CharlieFRuan committed Aug 15, 2024
1 parent c7a7285 commit b598fc1
Showing 2 changed files with 70 additions and 6 deletions.
examples/README.md (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
 - [next-simple-chat](next-simple-chat): a minimal and complete chatbot app with [Next.js](https://nextjs.org/).
 - [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi-round chat usage can reuse the KV cache
 - [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
-- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com)
+- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com), and RAG with Langchain.js using WebLLM for both LLM and embedding in a single engine
 - [multi-models](multi-models): demonstrates loading multiple models in a single engine concurrently
 
 #### Advanced OpenAI API Capabilities
examples/embeddings/src/embeddings.ts (69 additions, 5 deletions)
@@ -2,6 +2,12 @@ import * as webllm from "@mlc-ai/web-llm";
 import { MemoryVectorStore } from "langchain/vectorstores/memory";
 import type { EmbeddingsInterface } from "@langchain/core/embeddings";
 import type { Document } from "@langchain/core/documents";
+import { formatDocumentsAsString } from "langchain/util/document";
+import { PromptTemplate } from "@langchain/core/prompts";
+import {
+  RunnableSequence,
+  RunnablePassthrough,
+} from "@langchain/core/runnables";
 
 function setLabel(id: string, text: string) {
   const label = document.getElementById(id);
@@ -18,12 +24,17 @@ const initProgressCallback = (report: webllm.InitProgressReport) => {
 // For integration with Langchain
 class WebLLMEmbeddings implements EmbeddingsInterface {
   engine: webllm.MLCEngineInterface;
-  constructor(engine: webllm.MLCEngineInterface) {
+  modelId: string;
+  constructor(engine: webllm.MLCEngineInterface, modelId: string) {
     this.engine = engine;
+    this.modelId = modelId;
   }
 
   async _embed(texts: string[]): Promise<number[][]> {
-    const reply = await this.engine.embeddings.create({ input: texts });
+    const reply = await this.engine.embeddings.create({
+      input: texts,
+      model: this.modelId,
+    });
     const result: number[][] = [];
     for (let i = 0; i < texts.length; i++) {
       result.push(reply.data[i].embedding);
@@ -82,7 +93,7 @@ async function webllmAPI() {

   // Calculate similarity (we use langchain here, but any method works)
   const vectorStore = await MemoryVectorStore.fromExistingIndex(
-    new WebLLMEmbeddings(engine),
+    new WebLLMEmbeddings(engine, selectedModel),
   );
   // See score
   for (let i = 0; i < queries_og.length; i++) {
@@ -113,7 +124,7 @@ async function langchainAPI() {
   );
 
   const vectorStore = await MemoryVectorStore.fromExistingIndex(
-    new WebLLMEmbeddings(engine),
+    new WebLLMEmbeddings(engine, selectedModel),
   );
   const document0: Document = {
     pageContent: documents[0],
@@ -142,6 +153,59 @@ async function langchainAPI() {
   }
 }
 
+// RAG with Langchain.js using WebLLM for both LLM and Embedding in a single engine
+// Followed https://js.langchain.com/v0.1/docs/expression_language/cookbook/retrieval/
+// There are many possible ways to achieve RAG (e.g. degree of integration with Langchain,
+// using WebWorker, etc.). We provide a minimal example here.
+async function simpleRAG() {
+  // 0. Load both embedding model and LLM to a single WebLLM Engine
+  const embeddingModelId = "snowflake-arctic-embed-m-q0f32-MLC-b4";
+  const llmModelId = "gemma-2-2b-it-q4f32_1-MLC-1k";
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    [embeddingModelId, llmModelId],
+    {
+      initProgressCallback: initProgressCallback,
+      logLevel: "INFO", // specify the log level
+    },
+  );
+
+  const vectorStore = await MemoryVectorStore.fromTexts(
+    ["mitochondria is the powerhouse of the cell"],
+    [{ id: 1 }],
+    new WebLLMEmbeddings(engine, embeddingModelId),
+  );
+  const retriever = vectorStore.asRetriever();
+
+  const prompt =
+    PromptTemplate.fromTemplate(`Answer the question based only on the following context:
+{context}
+Question: {question}`);
+
+  const chain = RunnableSequence.from([
+    {
+      context: retriever.pipe(formatDocumentsAsString),
+      question: new RunnablePassthrough(),
+    },
+    prompt,
+  ]);
+
+  const formattedPrompt = (
+    await chain.invoke("What is the powerhouse of the cell?")
+  ).toString();
+  const reply = await engine.chat.completions.create({
+    messages: [{ role: "user", content: formattedPrompt }],
+    model: llmModelId,
+  });
+
+  console.log(reply.choices[0].message.content);
+
+  /*
+  "The powerhouse of the cell is the mitochondria."
+  */
+}
+
 // Select one to run
-webllmAPI();
+// webllmAPI();
 // langchainAPI();
+simpleRAG();
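For the WebWorker variation mentioned in the commit message (not part of this commit), here is a minimal sketch of how the engine could be moved off the main thread with WebLLM's worker API; that `CreateWebWorkerMLCEngine` accepts the same model-id array as `CreateMLCEngine` is an assumption here:

```ts
// worker.ts — sketch only; the worker hosts the engine.
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => handler.onmessage(msg);
```

```ts
// main thread — sketch only; the rest of simpleRAG() would be unchanged,
// since the returned engine implements MLCEngineInterface.
import * as webllm from "@mlc-ai/web-llm";

const engine: webllm.MLCEngineInterface =
  await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    // Assumed: multi-model array works here as it does for CreateMLCEngine.
    ["snowflake-arctic-embed-m-q0f32-MLC-b4", "gemma-2-2b-it-q4f32_1-MLC-1k"],
    { initProgressCallback: (report) => console.log(report.text) },
  );
```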
