From 4ddc64066aedecda9a772414b4b249589ffc1102 Mon Sep 17 00:00:00 2001 From: Jie Min <66545235+Stefan824@users.noreply.github.com> Date: Sun, 8 Sep 2024 12:40:11 -0400 Subject: [PATCH] Add rank fusion - initial implementation (#2590) --- pom.xml | 12 + .../java/io/anserini/fusion/FuseTrecRuns.java | 131 +++++++++ .../io/anserini/fusion/RescoreMethod.java | 23 ++ src/main/java/io/anserini/fusion/TrecRun.java | 264 ++++++++++++++++++ .../java/io/anserini/fusion/TrecRunFuser.java | 153 ++++++++++ 5 files changed, 583 insertions(+) create mode 100644 src/main/java/io/anserini/fusion/FuseTrecRuns.java create mode 100644 src/main/java/io/anserini/fusion/RescoreMethod.java create mode 100644 src/main/java/io/anserini/fusion/TrecRun.java create mode 100644 src/main/java/io/anserini/fusion/TrecRunFuser.java diff --git a/pom.xml b/pom.xml index df84a30a6..86551972e 100644 --- a/pom.xml +++ b/pom.xml @@ -536,5 +536,17 @@ + + junit + junit + 4.13.2 + test + + + org.junit.jupiter + junit-jupiter-engine + 5.8.2 + test + diff --git a/src/main/java/io/anserini/fusion/FuseTrecRuns.java b/src/main/java/io/anserini/fusion/FuseTrecRuns.java new file mode 100644 index 000000000..5f2d45a11 --- /dev/null +++ b/src/main/java/io/anserini/fusion/FuseTrecRuns.java @@ -0,0 +1,131 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; + +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.ParserProperties; +import org.kohsuke.args4j.spi.StringArrayOptionHandler; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * Main entry point for Fusion. + */ +public class FuseTrecRuns { + private static final Logger LOG = LogManager.getLogger(FuseTrecRuns.class); + + public static class Args extends TrecRunFuser.Args { + @Option(name = "-options", required = false, usage = "Print information about options.") + public Boolean options = false; + + @Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "[file]", required = true, + usage = "Path to both run files to fuse") + public String[] runs; + + @Option (name = "-resort", required = false, metaVar = "[flag]", usage="We Resort the Trec run files or not") + public boolean resort = false; + } + + private final Args args; + private final TrecRunFuser fuser; + private final List runs = new ArrayList(); + + public FuseTrecRuns(Args args) throws IOException { + this.args = args; + this.fuser = new TrecRunFuser(args); + + LOG.info(String.format("============ Initializing %s ============", FuseTrecRuns.class.getSimpleName())); + LOG.info("Runs: " + Arrays.toString(args.runs)); + LOG.info("Run tag: " + args.runtag); + LOG.info("Fusion method: " + args.method); + LOG.info("Reciprocal Rank Fusion K value (rrf_k): " + args.rrf_k); + LOG.info("Alpha value for interpolation: " + args.alpha); + LOG.info("Max documents to output (k): " + args.k); + LOG.info("Pool depth: " + args.depth); + LOG.info("Resort TREC run files: " + args.resort); + + try { + // Ensure positive depth and k values + if (args.depth <= 0) { + throw new IllegalArgumentException("Option depth must be greater than 0"); + } + if (args.k <= 0) { + throw new IllegalArgumentException("Option k must be greater than 0"); + } + } catch (Exception e) { + throw new IllegalArgumentException(String.format("Error: %s. Please check the provided arguments. Use the \"-options\" flag to print out detailed information about available options and their usage.\n", + e.getMessage())); + } + + for (String runFile : args.runs) { + try { + Path path = Paths.get(runFile); + TrecRun run = new TrecRun(path, args.resort); + runs.add(run); + } catch (Exception e) { + throw new IllegalArgumentException(String.format("Error: %s. Please check the provided arguments. Use the \"-options\" flag to print out detailed information about available options and their usage.\n", + e.getMessage())); + } + } + } + + public void run() throws IOException { + LOG.info("============ Launching Fusion ============"); + fuser.fuse(runs); + } + + public static void main(String[] args) throws Exception { + Args fuseArgs = new Args(); + CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + if (fuseArgs.options) { + System.err.printf("Options for %s:\n\n", FuseTrecRuns.class.getSimpleName()); + parser.printUsage(System.err); + ArrayList required = new ArrayList<>(); + parser.getOptions().forEach(option -> { + if (option.option.required()) { + required.add(option.option.toString()); + } + }); + System.err.printf("\nRequired options are %s\n", required); + } else { + System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", + e.getMessage()); + } + return; + } + + try { + FuseTrecRuns fuser = new FuseTrecRuns(fuseArgs); + fuser.run(); + } catch (Exception e) { + System.err.println(e.getMessage()); + } + } +} diff --git a/src/main/java/io/anserini/fusion/RescoreMethod.java b/src/main/java/io/anserini/fusion/RescoreMethod.java new file mode 100644 index 000000000..e07e9f082 --- /dev/null +++ b/src/main/java/io/anserini/fusion/RescoreMethod.java @@ -0,0 +1,23 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; + +public enum RescoreMethod { + RRF, + SCALE, + NORMALIZE; +} diff --git a/src/main/java/io/anserini/fusion/TrecRun.java b/src/main/java/io/anserini/fusion/TrecRun.java new file mode 100644 index 000000000..14c63c7c2 --- /dev/null +++ b/src/main/java/io/anserini/fusion/TrecRun.java @@ -0,0 +1,264 @@ +/* +* Anserini: A Lucene toolkit for reproducible information retrieval research +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package io.anserini.fusion; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.io.FileUtils; + +/** + * Wrapper class for a TREC run. +*/ +public class TrecRun { + // Enum representing the columns in the TREC run file + public enum Column { + TOPIC, Q0, DOCID, RANK, SCORE, TAG + } + + private List> runData; + private Path filepath = null; + private Boolean reSort = false; + + // Constructor without reSort parameter + public TrecRun(Path filepath) throws IOException { + this(filepath, false); + } + + // Constructor with reSort parameter + public TrecRun(Path filepath, Boolean reSort) throws IOException { + this.resetData(); + this.filepath = filepath; + this.reSort = reSort; + this.readRun(filepath); + } + + // Constructor without parameters + public TrecRun() { + this.resetData(); + } + + private void resetData() { + runData = new ArrayList<>(); + } + + /** + * Reads a TREC run file and loads its data into the runData list. + * + * @param filepath Path to the TREC run file. + * @throws IOException If the file cannot be read. + */ + public void readRun(Path filepath) throws IOException { + try (BufferedReader br = new BufferedReader(new FileReader(filepath.toFile()))) { + String line; + while ((line = br.readLine()) != null) { + String[] data = line.split("\\s+"); + Map record = new EnumMap<>(Column.class); + + // Populate the record map with the parsed data + record.put(Column.TOPIC, data[0]); + record.put(Column.Q0, data[1]); + record.put(Column.DOCID, data[2]); + record.put(Column.TAG, data[5]); + + // Parse RANK as integer + int rankInt = Integer.parseInt(data[3]); + record.put(Column.RANK, rankInt); + + // Parse SCORE as double + double scoreFloat = Double.parseDouble(data[4]); + record.put(Column.SCORE, scoreFloat); + + // Add the record to runData + runData.add(record); + } + } + + if (reSort) { + runData.sort((record1, record2) -> { + int topicComparison = ((String)record1.get(Column.TOPIC)).compareTo((String)(record2.get(Column.TOPIC))); + if (topicComparison != 0) { + return topicComparison; + } + return Double.compare((Double)(record2.get(Column.SCORE)), (Double)record1.get(Column.SCORE)); + }); + String currentTopic = ""; + int rank = 1; + for (Map record : runData) { + String topic = (String) record.get(Column.TOPIC); + if (!topic.equals(currentTopic)) { + currentTopic = topic; + rank = 1; + } + record.put(Column.RANK, rank); + rank++; + } + } + } + + public Set getTopics() { + return runData.stream().map(record -> (String) record.get(Column.TOPIC)).collect(Collectors.toSet()); + } + + public TrecRun cloneRun() throws IOException { + TrecRun clone = new TrecRun(); + clone.runData = new ArrayList<>(this.runData); + clone.filepath = this.filepath; + clone.reSort = this.reSort; + return clone; + } + + /** + * Saves the TREC run data to a text file in the TREC run format. + * + * @param outputPath Path to the output file. + * @param tag Tag to be added to each record in the TREC run file. If null, the existing tags are retained. + * @throws IOException If an I/O error occurs while writing to the file. + * @throws IllegalStateException If the runData list is empty. + */ + public void saveToTxt(Path outputPath, String tag) throws IOException { + if (runData.isEmpty()) { + throw new IllegalStateException("Nothing to save. TrecRun is empty"); + } + if (tag != null) { + runData.forEach(record -> record.put(Column.TAG, tag)); + } + runData.sort(Comparator.comparing((Map r) -> Integer.parseInt((String) r.get(Column.TOPIC))) + .thenComparing(r -> (Double) r.get(Column.SCORE), Comparator.reverseOrder())); + FileUtils.writeLines(outputPath.toFile(), runData.stream() + .map(record -> record.entrySet().stream() + .map(entry -> { + if (entry.getKey() == Column.SCORE) { + return String.format("%.6f", entry.getValue()); + } else { + return entry.getValue().toString(); + } + }) + .collect(Collectors.joining(" "))) + .collect(Collectors.toList())); + } + + public List> getDocsByTopic(String topic, int maxDocs) { + return runData.stream() + .filter(record -> record.get(Column.TOPIC).equals(topic)) // Filter by topic + .limit(maxDocs > 0 ? maxDocs : Integer.MAX_VALUE) // Limit the number of docs if maxDocs > 0 + .collect(Collectors.toList()); // Collect as List> + } + + public TrecRun rescore(RescoreMethod method, int rrfK, double scale) { + switch (method) { + case RRF -> rescoreRRF(rrfK); + case SCALE -> rescoreScale(scale); + case NORMALIZE -> normalizeScores(); + default -> throw new UnsupportedOperationException("Unknown rescore method: " + method); + } + return this; + } + + private void rescoreRRF(int rrfK) { + runData.forEach(record -> { + double score = 1.0 / (rrfK + (Integer)(record.get(Column.RANK))); + record.put(Column.SCORE, score); + }); + } + + private void rescoreScale(double scale) { + runData.forEach(record -> { + double score = (Double) record.get(Column.SCORE) * scale; + record.put(Column.SCORE, score); + }); + } + + private void normalizeScores() { + for (String topic : getTopics()) { + List> topicRecords = runData.stream() + .filter(record -> record.get(Column.TOPIC).equals(topic)) + .collect(Collectors.toList()); + + double minScore = topicRecords.stream() + .mapToDouble(record -> (Double) record.get(Column.SCORE)) + .min().orElse(0.0); + double maxScore = topicRecords.stream() + .mapToDouble(record -> (Double) record.get(Column.SCORE)) + .max().orElse(1.0); + + for (Map record : topicRecords) { + double normalizedScore = ((Double) record.get(Column.SCORE) - minScore) / (maxScore - minScore); + record.put(Column.SCORE, normalizedScore); + } + } + } + + /** + * Merges multiple TrecRun instances into a single TrecRun instance. + * The merged run will contain the top documents for each topic, with scores summed across the input runs. + * + * @param runs List of TrecRun instances to merge. + * @param depth Maximum number of documents to consider from each run for each topic (null for no limit). + * @param k Maximum number of top documents to include in the merged run for each topic (null for no limit). + * @return A new TrecRun instance containing the merged results. + * @throws IllegalArgumentException if less than 2 runs are provided. + */ + public static TrecRun merge(List runs, Integer depth, Integer k) { + if (runs.size() < 2) { + throw new IllegalArgumentException("Merge requires at least 2 runs."); + } + + TrecRun mergedRun = new TrecRun(); + + Set topics = runs.stream().flatMap(run -> run.getTopics().stream()).collect(Collectors.toSet()); + + topics.forEach(topic -> { + Map docScores = new HashMap<>(); + for (TrecRun run : runs) { + run.getDocsByTopic(topic, depth != null ? depth : Integer.MAX_VALUE).forEach(record -> { + String docId = (String) record.get(Column.DOCID); + double score = (Double) record.get(Column.SCORE); + docScores.put(docId, docScores.getOrDefault(docId, 0.0) + score); + }); + } + List> sortedDocScores = docScores.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .limit(k != null ? k : Integer.MAX_VALUE) + .collect(Collectors.toList()); + + for (int rank = 0; rank < sortedDocScores.size(); rank++) { + Map.Entry entry = sortedDocScores.get(rank); + Map record = new EnumMap<>(Column.class); + record.put(Column.TOPIC, topic); + record.put(Column.Q0, "Q0"); + record.put(Column.DOCID, entry.getKey()); + record.put(Column.RANK, rank + 1); + record.put(Column.SCORE, entry.getValue()); + record.put(Column.TAG, "merge_sum"); + mergedRun.runData.add(record); + } + }); + + return mergedRun; + } +} \ No newline at end of file diff --git a/src/main/java/io/anserini/fusion/TrecRunFuser.java b/src/main/java/io/anserini/fusion/TrecRunFuser.java new file mode 100644 index 000000000..4323e8b10 --- /dev/null +++ b/src/main/java/io/anserini/fusion/TrecRunFuser.java @@ -0,0 +1,153 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +import org.kohsuke.args4j.Option; + +/** + * Main logic class for Fusion + */ +public class TrecRunFuser { + private final Args args; + + private static final String METHOD_RRF = "rrf"; + private static final String METHOD_INTERPOLATION = "interpolation"; + private static final String METHOD_AVERAGE = "average"; + + public static class Args { + @Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output") + public String output; + + @Option(name = "-runtag", metaVar = "[runtag]", required = false, usage = "Run tag for the fusion") + public String runtag = "anserini.fusion"; + + @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method") + public String method = "rrf"; + + @Option(name = "-rrf_k", metaVar = "[number]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") + public int rrf_k = 60; + + @Option(name = "-alpha", metaVar = "[value]", required = false, usage = "Alpha value used for interpolation.") + public double alpha = 0.5; + + @Option(name = "-k", metaVar = "[number]", required = false, usage = "number of documents to output for topic") + public int k = 1000; + + @Option(name = "-depth", metaVar = "[number]", required = false, usage = "Pool depth per topic.") + public int depth = 1000; + } + + public TrecRunFuser(Args args) { + this.args = args; + } + + /** + * Perform fusion by averaging on a list of TrecRun objects. + * + * @param runs List of TrecRun objects. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via averaging. + */ + public static TrecRun average(List runs, int depth, int k) { + + for (TrecRun run : runs) { + run.rescore(RescoreMethod.SCALE, 0, (1/(double)runs.size())); + } + + return TrecRun.merge(runs, depth, k); + } + + /** + * Perform reciprocal rank fusion on a list of TrecRun objects. Implementation follows Cormack et al. + * (SIGIR 2009) paper titled "Reciprocal Rank Fusion Outperforms Condorcet and Individual Rank Learning Methods." + * + * @param runs List of TrecRun objects. + * @param rrf_k Parameter to avoid vanishing importance of lower-ranked documents. Note that this is different from the *k* in top *k* retrieval; set to 60 by default, per Cormack et al. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via reciprocal rank fusion. + */ + public static TrecRun reciprocalRankFusion(List runs, int rrf_k, int depth, int k) { + + for (TrecRun run : runs) { + run.rescore(RescoreMethod.RRF, rrf_k, 0); + } + + return TrecRun.merge(runs, depth, k); + } + + /** + * Perform fusion by interpolation on a list of exactly two TrecRun objects. + * new_score = first_run_score * alpha + (1 - alpha) * second_run_score. + * + * @param runs List of TrecRun objects. Exactly two runs. + * @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via interpolation. + */ + public static TrecRun interpolation(List runs, double alpha, int depth, int k) { + // Ensure exactly 2 runs are provided, as interpolation requires 2 runs + if (runs.size() != 2) { + throw new IllegalArgumentException("Interpolation requires exactly 2 runs"); + } + + runs.get(0).rescore(RescoreMethod.SCALE, 0, alpha); + runs.get(1).rescore(RescoreMethod.SCALE, 0, 1 - alpha); + + return TrecRun.merge(runs, depth, k); + } + + private void saveToTxt(TrecRun fusedRun) throws IOException { + Path outputPath = Paths.get(args.output); + fusedRun.saveToTxt(outputPath, args.runtag); + } + + /** + * Process the fusion of TrecRun objects based on the specified method. + * + * @param runs List of TrecRun objects to be fused. + * @throws IOException If an I/O error occurs while saving the output. + */ + public void fuse(List runs) throws IOException { + TrecRun fusedRun; + + // Select fusion method + switch (args.method.toLowerCase()) { + case METHOD_RRF: + fusedRun = reciprocalRankFusion(runs, args.rrf_k, args.depth, args.k); + break; + case METHOD_INTERPOLATION: + fusedRun = interpolation(runs, args.alpha, args.depth, args.k); + break; + case METHOD_AVERAGE: + fusedRun = average(runs, args.depth, args.k); + break; + default: + throw new IllegalArgumentException("Unknown fusion method: " + args.method + + ". Supported methods are: average, rrf, interpolation."); + } + + saveToTxt(fusedRun); + } +}