From 76cf2387031dd573e70d639e781e823dd510001d Mon Sep 17 00:00:00 2001
From: Paniz
Date: Mon, 15 Jul 2024 00:18:28 -0400
Subject: [PATCH] Add SafeTensorsDenseVectorCollection and a matching dense-vector document generator

---
 .../SafeTensorsDenseVectorCollection.java     | 206 +++++++++++++
 .../io/anserini/index/AbstractIndexer.java    |   4 +-
 ...feTensorsDenseVectorDocumentGenerator.java | 273 ------------------
 ...feTensorsDenseVectorDocumentGenerator.java |  59 ++++
 4 files changed, 266 insertions(+), 276 deletions(-)
 create mode 100644 src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java
 delete mode 100644 src/main/java/io/anserini/index/generator/HnswJsonWithSafeTensorsDenseVectorDocumentGenerator.java
 create mode 100644 src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java

diff --git a/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java b/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java
new file mode 100644
index 0000000000..04b93a2d5e
--- /dev/null
+++ b/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java
@@ -0,0 +1,206 @@
+package io.anserini.collection;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.stream.Stream;
+
+/**
+ * A collection of dense vectors stored in the safetensors format, with vectors and
+ * document ids held in two parallel safetensors files.
+ */
+public class SafeTensorsDenseVectorCollection extends DocumentCollection<SafeTensorsDenseVectorCollection.Document> {
+  private static final Logger LOG = LogManager.getLogger(SafeTensorsDenseVectorCollection.class);
+  private String vectorsFilePath;
+  private String docidsFilePath;
+  public double[][] vectors;
+  public String[] docids;
+
+  public SafeTensorsDenseVectorCollection(Path path) throws IOException {
+    this.path = path;
+    generateFilePaths(path);
+    readData();
+  }
+
+  @Override
+  public FileSegment<SafeTensorsDenseVectorCollection.Document> createFileSegment(Path p) throws IOException {
+    return new SafeTensorsDenseVectorCollection.Segment(p, vectors, docids);
+  }
+
+  @Override
+  public FileSegment<SafeTensorsDenseVectorCollection.Document> createFileSegment(BufferedReader bufferedReader) throws IOException {
+    throw new UnsupportedOperationException("BufferedReader is not supported for SafeTensorsDenseVectorCollection.");
+  }
+
+  private void generateFilePaths(Path inputFolder) throws IOException {
+    String inputFileName;
+    try (Stream<Path> files = Files.list(inputFolder)) {
+      inputFileName = files
+          .filter(file -> file.toString().endsWith(".safetensors"))
+          .map(file -> file.getFileName().toString())
+          .findFirst()
+          .orElseThrow(() -> new IOException("No valid input file found in the directory"));
+    }
+
+    Path parent = inputFolder.getParent();
+    String baseName = inputFileName.replace(".safetensors", "");
+    vectorsFilePath = Paths.get(parent.toString(), baseName + "_vectors.safetensors").toString();
+    docidsFilePath = Paths.get(parent.toString(), baseName + "_docids.safetensors").toString();
+  }
+
+  private void readData() throws IOException {
+    vectors = readVectors(vectorsFilePath);
+    docids = readDocidAsciiValues(docidsFilePath);
+  }
+
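+  // A note on the on-disk format the readers below assume (the standard
+  // safetensors layout): the file begins with an 8-byte little-endian integer
+  // giving the length of a JSON header; the header maps each tensor name to its
+  // dtype, shape, and data_offsets; the raw tensor bytes follow. data_offsets
+  // are relative to the first byte after the header, which is why the readers
+  // position the buffer at begin + headerSize + 8.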
+  private double[][] readVectors(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractVectors(data, header);
+  }
+
+  private String[] readDocidAsciiValues(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractDocids(data, header);
+  }
+
+  @SuppressWarnings("unchecked")
+  private Map<String, Object> parseHeader(byte[] data) throws IOException {
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    long headerSize = buffer.getLong();
+    byte[] headerBytes = new byte[(int) headerSize];
+    buffer.get(headerBytes);
+    String headerJson = new String(headerBytes, StandardCharsets.UTF_8).trim();
+    ObjectMapper objectMapper = new ObjectMapper();
+    return objectMapper.readValue(headerJson, Map.class);
+  }
+
+  @SuppressWarnings("unchecked")
+  private double[][] extractVectors(byte[] data, Map<String, Object> header) {
+    Map<String, Object> vectorsInfo = (Map<String, Object>) header.get("vectors");
+    String dtype = (String) vectorsInfo.get("dtype");
+
+    List<Integer> shapeList = (List<Integer>) vectorsInfo.get("shape");
+    int rows = shapeList.get(0);
+    int cols = shapeList.get(1);
+    List<Number> dataOffsets = (List<Number>) vectorsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    // Skip the 8-byte length prefix and the JSON header, then seek to the tensor data.
+    buffer.position((int) (begin + buffer.getLong(0) + 8));
+
+    double[][] vectors = new double[rows][cols];
+    if (dtype.equals("F64")) {
+      for (int i = 0; i < rows; i++) {
+        for (int j = 0; j < cols; j++) {
+          vectors[i][j] = buffer.getDouble();
+        }
+      }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
+    }
+
+    return vectors;
+  }
+
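+  // How docids are decoded, worked through (example values illustrative, not
+  // taken from any actual data file): each docid is stored as one row of
+  // maxCols I64 values, one character code per value, zero-padded on the right.
+  // With maxCols = 6, the id "doc1" is stored as [100, 111, 99, 49, 0, 0] and
+  // decodes back to "doc1".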
+  @SuppressWarnings("unchecked")
+  private String[] extractDocids(byte[] data, Map<String, Object> header) {
+    Map<String, Object> docidsInfo = (Map<String, Object>) header.get("docids");
+    String dtype = (String) docidsInfo.get("dtype");
+
+    List<Integer> shapeList = (List<Integer>) docidsInfo.get("shape");
+    int length = shapeList.get(0);
+    int maxCols = shapeList.get(1);
+
+    List<Number> dataOffsets = (List<Number>) docidsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    buffer.position((int) (begin + buffer.getLong(0) + 8));
+
+    String[] docids = new String[length];
+    StringBuilder sb = new StringBuilder();
+    if (dtype.equals("I64")) {
+      for (int i = 0; i < length; i++) {
+        sb.setLength(0);
+        for (int j = 0; j < maxCols; j++) {
+          char c = (char) buffer.getLong();
+          if (c != 0)
+            sb.append(c);
+        }
+        docids[i] = sb.toString();
+      }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
+    }
+
+    return docids;
+  }
+
+  public static class Segment extends FileSegment<SafeTensorsDenseVectorCollection.Document> {
+    private double[][] vectors;
+    private String[] docids;
+    private int currentIndex;
+
+    public Segment(Path path, double[][] vectors, String[] docids) throws IOException {
+      super(path);
+      this.vectors = vectors;
+      this.docids = docids;
+      this.currentIndex = 0;
+    }
+
+    @Override
+    protected void readNext() throws IOException, NoSuchElementException {
+      if (currentIndex >= docids.length) {
+        atEOF = true;
+        throw new NoSuchElementException("End of file reached");
+      }
+
+      String id = docids[currentIndex];
+      double[] vector = vectors[currentIndex];
+      bufferedRecord = new SafeTensorsDenseVectorCollection.Document(id, vector, "");
+      currentIndex++;
+    }
+  }
+
+  public static class Document implements SourceDocument {
+    private final String id;
+    private final double[] vector;
+    private final String raw;
+
+    public Document(String id, double[] vector, String raw) {
+      this.id = id;
+      this.vector = vector;
+      this.raw = raw;
+    }
+
+    @Override
+    public String id() {
+      return id;
+    }
+
+    @Override
+    public String contents() {
+      return Arrays.toString(vector);
+    }
+
+    @Override
+    public String raw() {
+      return raw;
+    }
+
+    @Override
+    public boolean indexable() {
+      return true;
+    }
+  }
+}
diff --git a/src/main/java/io/anserini/index/AbstractIndexer.java b/src/main/java/io/anserini/index/AbstractIndexer.java
index a0f867b34f..05c27d6fcf 100644
--- a/src/main/java/io/anserini/index/AbstractIndexer.java
+++ b/src/main/java/io/anserini/index/AbstractIndexer.java
@@ -324,9 +324,7 @@ protected void processSegments(ThreadPoolExecutor executor, List<Path> segmentPa
         // Each thread gets its own document generator, so we don't need to make any assumptions about its thread safety.
         @SuppressWarnings("unchecked")
         LuceneDocumentGenerator<SourceDocument> generator = (LuceneDocumentGenerator<SourceDocument>)
-            // generatorClass.getDeclaredConstructor((Class<?>[]) null).newInstance();
-            generatorClass.getDeclaredConstructor(Args.class).newInstance(args);
-
+            generatorClass.getDeclaredConstructor((Class<?>[]) null).newInstance();
 
         executor.execute(new IndexerThread(segmentPath, generator));
       } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
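Note: with the AbstractIndexer change above, document generators are once again
constructed reflectively through their no-argument constructor. A minimal sketch
of the contract a generator class must satisfy under this scheme, with
MyGenerator as a purely hypothetical example (it is not part of this patch):

  package io.anserini.index.generator;

  import io.anserini.collection.SourceDocument;
  import org.apache.lucene.document.Document;
  import org.apache.lucene.document.Field;
  import org.apache.lucene.document.StringField;

  public class MyGenerator<T extends SourceDocument> implements LuceneDocumentGenerator<T> {
    // AbstractIndexer calls getDeclaredConstructor((Class<?>[]) null), so the
    // class must expose a public no-argument constructor.
    public MyGenerator() {}

    @Override
    public Document createDocument(T src) throws InvalidDocumentException {
      // Minimal body: store just the document id.
      Document document = new Document();
      document.add(new StringField("id", src.id(), Field.Store.YES));
      return document;
    }
  }

Observe that SafeTensorsDenseVectorDocumentGenerator below only declares a
constructor taking a SafeTensorsDenseVectorCollection, so as written it cannot
be instantiated through this no-argument path.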
LOG.info("Initialized with HNSW specific settings: M=" + hnswArgs.M); - } else { - // Generic or other specific initializations - LOG.info("Initialized with generic settings"); - } - - if (args.input == null || args.input.isEmpty()) { - LOG.error("Input path is not provided."); - throw new IllegalArgumentException("Input path is not provided."); - } - } - - @Override - public Document createDocument(T src) throws InvalidDocumentException { - try { - LOG.info("Input path for createDocument: " + this.args.input); - - if (this.args.input == null) { - LOG.error("Input path is null"); - throw new InvalidDocumentException(); - } - - Path inputFolder = Paths.get(this.args.input); - - FilePaths filePaths = generateFilePaths(inputFolder); - - if (filePaths == null) { - LOG.error("Error generating file paths"); - throw new InvalidDocumentException(); - } - - LOG.info("Generated file paths: "); - LOG.info(" - Vectors: " + filePaths.vectorsFilePath); - LOG.info(" - Docids: " + filePaths.docidsFilePath); - - if (filePaths.vectorsFilePath == null || filePaths.docidsFilePath == null) { - LOG.error("Error generating file paths"); - throw new InvalidDocumentException(); - } - - // Read vectors and docids from safetensors - double[][] vectors = readVectors(filePaths.vectorsFilePath); - String[] docids = readDocidAsciiValues(filePaths.docidsFilePath); - - String id = src.id(); - LOG.info("Processing document ID: " + id); - int index = Arrays.asList(docids).indexOf(id); - - if (index == -1) { - LOG.error("Error finding index for document ID: " + id); - LOG.error("Document ID ASCII: " + Arrays.toString(id.chars().toArray())); - LOG.error("Available IDs ASCII: " + Arrays.deepToString(docids)); - throw new InvalidDocumentException(); - } - - float[] contents = new float[vectors[index].length]; - for (int i = 0; i < contents.length; i++) { - contents[i] = (float) vectors[index][i]; - } - - final Document document = new Document(); - document.add(new StringField(Constants.ID, id, Field.Store.YES)); - document.add(new BinaryDocValuesField(Constants.ID, new BytesRef(id))); - document.add(new KnnFloatVectorField(Constants.VECTOR, contents, VectorSimilarityFunction.DOT_PRODUCT)); - return document; - } catch (Exception e) { - LOG.error("Error creating document", e); - LOG.error("trace: " + Arrays.toString(e.getStackTrace())); - - LOG.error("Document ID: " + src.id()); - LOG.error("Document contents: " + src.contents()); - LOG.error("Paths: " + this.args.input); - - throw new InvalidDocumentException(); - } - } - - public FilePaths generateFilePaths(Path inputFolder) throws IOException { - String inputFileName; - try (Stream files = Files.list(inputFolder)) { - inputFileName = files - .filter(file -> allowedFileSuffix.stream().anyMatch(suffix -> file.getFileName().toString().endsWith(suffix))) - .map(file -> file.getFileName().toString()) - .findFirst() - .orElseThrow(() -> new IOException("No valid input file found in the directory")); - } - - Path grandParent = inputFolder.getParent().getParent(); - Path parent = inputFolder.getParent(); - Path safetensorsFolder = Paths.get(grandParent.toString() + "/" + parent.getFileName().toString() + ".safetensors", - inputFolder.getFileName().toString()); - - String baseName = inputFileName.replace(".jsonl", "").replace(".json", "").replace(".gz", ""); - String vectorsFilePath = Paths.get(safetensorsFolder.toString(), baseName + "_vectors.safetensors").toString(); - String docidsFilePath = Paths.get(safetensorsFolder.toString(), baseName + "_docids.safetensors").toString(); - - 
-    return new FilePaths(vectorsFilePath, docidsFilePath);
-  }
-
-  public static class FilePaths {
-    public String vectorsFilePath;
-    public String docidsFilePath;
-
-    public FilePaths(String vectorsFilePath, String docidsFilePath) {
-      this.vectorsFilePath = vectorsFilePath;
-      this.docidsFilePath = docidsFilePath;
-    }
-  }
-
-  private double[][] readVectors(String filePath) throws IOException {
-    byte[] data = Files.readAllBytes(Paths.get(filePath));
-    Map<String, Object> header = parseHeader(data);
-    return extractVectors(data, header);
-  }
-
-  private String[] readDocidAsciiValues(String filePath) throws IOException {
-    byte[] data = Files.readAllBytes(Paths.get(filePath));
-    Map<String, Object> header = parseHeader(data);
-    return extractDocids(data, header);
-  }
-
-  @SuppressWarnings("unchecked")
-  private static Map<String, Object> parseHeader(byte[] data) throws IOException {
-    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
-    long headerSize = buffer.getLong();
-    byte[] headerBytes = new byte[(int) headerSize];
-    buffer.get(headerBytes);
-    String headerJson = new String(headerBytes, StandardCharsets.UTF_8).trim();
-    System.out.println("Header JSON: " + headerJson);
-    ObjectMapper objectMapper = new ObjectMapper();
-    return objectMapper.readValue(headerJson, Map.class);
-  }
-
-  private static double[][] extractVectors(byte[] data, Map<String, Object> header) {
-    @SuppressWarnings("unchecked")
-    Map<String, Object> vectorsInfo = (Map<String, Object>) header.get("vectors");
-    String dtype = (String) vectorsInfo.get("dtype");
-
-    @SuppressWarnings("unchecked")
-    List<Integer> shapeList = (List<Integer>) vectorsInfo.get("shape");
-    int rows = shapeList.get(0);
-    int cols = shapeList.get(1);
-    @SuppressWarnings("unchecked")
-    List<Number> dataOffsets = (List<Number>) vectorsInfo.get("data_offsets");
-    long begin = dataOffsets.get(0).longValue();
-    long end = dataOffsets.get(1).longValue();
-
-    System.out.println("Vectors shape: " + rows + "x" + cols);
-    System.out.println("Data offsets: " + begin + " to " + end);
-    System.out.println("Data type: " + dtype);
-
-    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
-    // Correctly position the buffer to start reading after the header
-    buffer.position((int) (begin + buffer.getLong(0) + 8));
-
-    double[][] vectors = new double[rows][cols];
-    if (dtype.equals("F64")) {
-      for (int i = 0; i < rows; i++) {
-        for (int j = 0; j < cols; j++) {
-          vectors[i][j] = buffer.getDouble();
-        }
-      }
-    } else {
-      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
-    }
-
-    // Log the first few rows and columns to verify the content
-    System.out.println("First few vectors:");
-    for (int i = 0; i < Math.min(5, rows); i++) {
-      for (int j = 0; j < Math.min(10, cols); j++) {
-        System.out.print(vectors[i][j] + " ");
-      }
-      System.out.println();
-    }
-
-    return vectors;
-  }
-
-  @SuppressWarnings("unchecked")
-  private static String[] extractDocids(byte[] data, Map<String, Object> header) {
-    Map<String, Object> docidsInfo = (Map<String, Object>) header.get("docids");
-    String dtype = (String) docidsInfo.get("dtype");
-
-    List<Integer> shapeList = (List<Integer>) docidsInfo.get("shape");
-    int length = shapeList.get(0);
-    int maxCols = shapeList.get(1);
-
-    List<Number> dataOffsets = (List<Number>) docidsInfo.get("data_offsets");
-    long begin = dataOffsets.get(0).longValue();
-    long end = dataOffsets.get(1).longValue();
-
-    System.out.println("Docids shape: " + length + "x" + maxCols);
-    System.out.println("Data offsets: " + begin + " to " + end);
-    System.out.println("Data type: " + dtype);
-
-    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
-    // Correctly position the buffer to start reading after the header
-    buffer.position((int) (begin + buffer.getLong(0) + 8));
-
-    String[] docids = new String[length];
-    StringBuilder sb = new StringBuilder();
-    if (dtype.equals("I64")) {
-      for (int i = 0; i < length; i++) {
-        sb.setLength(0);
-        for (int j = 0; j < maxCols; j++) {
-          char c = (char) buffer.getLong();
-          if (c != 0)
-            sb.append(c);
-        }
-        docids[i] = sb.toString();
-      }
-    } else {
-      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
-    }
-
-    // Log the first few docid indices to verify the content
-    System.out.println("First few docids:");
-    for (int i = 0; i < Math.min(10, docids.length); i++) {
-      System.out.println(docids[i]);
-    }
-
-    return docids;
-  }
-}
diff --git a/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java b/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java
new file mode 100644
index 0000000000..c09c23804c
--- /dev/null
+++ b/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java
@@ -0,0 +1,59 @@
+package io.anserini.index.generator;
+
+import io.anserini.collection.SourceDocument;
+import io.anserini.collection.SafeTensorsDenseVectorCollection;
+import io.anserini.index.Constants;
+import org.apache.lucene.document.BinaryDocValuesField;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.BytesRef;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.util.Arrays;
+
+public class SafeTensorsDenseVectorDocumentGenerator<T extends SourceDocument> implements LuceneDocumentGenerator<T> {
+  private static final Logger LOG = LogManager.getLogger(SafeTensorsDenseVectorDocumentGenerator.class);
+  private SafeTensorsDenseVectorCollection collection;
+
+  public SafeTensorsDenseVectorDocumentGenerator(SafeTensorsDenseVectorCollection collection) {
+    this.collection = collection;
+  }
+
+  @Override
+  public Document createDocument(T src) throws InvalidDocumentException {
+    try {
+      LOG.info("Processing document ID: " + src.id());
+      float[] contents = getVectorForDocId(src.id());
+
+      if (contents == null) {
+        throw new InvalidDocumentException();
+      }
+
+      final Document document = new Document();
+      document.add(new StringField(Constants.ID, src.id(), Field.Store.YES));
+      document.add(new BinaryDocValuesField(Constants.ID, new BytesRef(src.id())));
+      document.add(new KnnFloatVectorField(Constants.VECTOR, contents, VectorSimilarityFunction.DOT_PRODUCT));
+
+      return document;
+    } catch (Exception e) {
+      LOG.error("Error creating document", e);
+      throw new InvalidDocumentException();
+    }
+  }
+
+  private float[] getVectorForDocId(String docId) {
+    // Linear scan over the collection's docids; acceptable for small collections,
+    // but a precomputed docid-to-index map would be preferable at scale.
+    int index = Arrays.asList(collection.docids).indexOf(docId);
+    if (index == -1) {
+      return null;
+    }
+    float[] vector = new float[collection.vectors[index].length];
+    for (int i = 0; i < vector.length; i++) {
+      vector[i] = (float) collection.vectors[index][i];
+    }
+    return vector;
+  }
+}
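
For reference, a minimal end-to-end sketch of how the new collection and
generator are meant to fit together. The input directory name and the main(...)
wrapper are illustrative only; the collection expects the directory to contain a
.safetensors file, from which it derives the *_vectors.safetensors and
*_docids.safetensors paths in the parent directory:

  import io.anserini.collection.SafeTensorsDenseVectorCollection;
  import io.anserini.index.generator.SafeTensorsDenseVectorDocumentGenerator;
  import org.apache.lucene.document.Document;

  import java.nio.file.Path;
  import java.nio.file.Paths;

  public class SafeTensorsExample {
    public static void main(String[] args) throws Exception {
      // Hypothetical input directory containing e.g. part00.safetensors; the
      // collection then reads part00_vectors.safetensors and
      // part00_docids.safetensors from that directory's parent.
      Path input = Paths.get("collections/dense-vectors");
      SafeTensorsDenseVectorCollection collection = new SafeTensorsDenseVectorCollection(input);

      SafeTensorsDenseVectorDocumentGenerator<SafeTensorsDenseVectorCollection.Document> generator =
          new SafeTensorsDenseVectorDocumentGenerator<>(collection);

      // FileSegment is Iterable, so each record can be converted in turn.
      for (SafeTensorsDenseVectorCollection.Document doc : collection.createFileSegment(input)) {
        Document luceneDoc = generator.createDocument(doc);
        System.out.println(luceneDoc.get("id") + " -> dense vector indexed");
      }
    }
  }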