diff --git a/src/main/java/io/moderne/ai/search/FindCodeThatResembles.java b/src/main/java/io/moderne/ai/search/FindCodeThatResembles.java
index 5ea5ce9..d17d9b1 100644
--- a/src/main/java/io/moderne/ai/search/FindCodeThatResembles.java
+++ b/src/main/java/io/moderne/ai/search/FindCodeThatResembles.java
@@ -1,7 +1,7 @@
 package io.moderne.ai.search;
 
-import io.moderne.ai.table.EmbeddingPerformance;
 import io.moderne.ai.EmbeddingModelClient;
+import io.moderne.ai.table.EmbeddingPerformance;
 import lombok.EqualsAndHashCode;
 import lombok.RequiredArgsConstructor;
 import org.openrewrite.*;
@@ -18,6 +18,8 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
+import static java.util.Objects.requireNonNull;
+
 @RequiredArgsConstructor
 @EqualsAndHashCode(callSuper = false)
 public class FindCodeThatResembles extends Recipe {
@@ -72,12 +74,14 @@ public TreeVisitor getVisitor() {
                 if (tree instanceof SourceFile) {
                     getCursor().putMessage("count", new AtomicInteger());
                     getCursor().putMessage("max", new AtomicLong());
+                    getCursor().putMessage("histogram", new EmbeddingPerformance.Histogram());
                     J visit = super.visit(tree, ctx);
                     if (getCursor().getMessage("count", new AtomicInteger()).get() > 0) {
-                        Duration max = Duration.ofNanos(getCursor().getMessage("max", new AtomicLong()).get());
+                        Duration max = Duration.ofNanos(requireNonNull(getCursor().<AtomicLong>getMessage("max")).get());
                         performance.insertRow(ctx, new EmbeddingPerformance.Row((
                                 (SourceFile) tree).getSourcePath().toString(),
-                                getCursor().getMessage("count", new AtomicInteger()).get(),
+                                requireNonNull(getCursor().<AtomicInteger>getMessage("count")).get(),
+                                requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getMessage("histogram")).getBuckets(),
                                 max));
                     }
                     return visit;
@@ -106,16 +110,16 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
                     EmbeddingModelClient.Relatedness related = modelClient.getRelatedness(resembles,
                             method.printTrimmed(getCursor()));
                     for (Duration timing : related.embeddingTimings()) {
-                        getCursor().getNearestMessage("count", new AtomicInteger(0)).incrementAndGet();
-                        AtomicLong max = getCursor().getNearestMessage("max", new AtomicLong(0));
-                        if (max.get() < timing.toNanos()) {
+                        requireNonNull(getCursor().<AtomicInteger>getNearestMessage("count")).incrementAndGet();
+                        requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getNearestMessage("histogram")).add(timing);
+                        AtomicLong max = getCursor().getNearestMessage("max");
+                        if (requireNonNull(max).get() < timing.toNanos()) {
                             max.set(timing.toNanos());
                         }
                     }
-                    if (related.isRelated()) {
-                        return SearchResult.found(method);
-                    }
-                    return super.visitMethodInvocation(method, ctx);
+                    return related.isRelated() ?
+                            SearchResult.found(method) :
+                            super.visitMethodInvocation(method, ctx);
                 }
             });
         }
diff --git a/src/main/java/io/moderne/ai/table/EmbeddingPerformance.java b/src/main/java/io/moderne/ai/table/EmbeddingPerformance.java
index be8d4fb..dd2d474 100644
--- a/src/main/java/io/moderne/ai/table/EmbeddingPerformance.java
+++ b/src/main/java/io/moderne/ai/table/EmbeddingPerformance.java
@@ -1,11 +1,14 @@
 package io.moderne.ai.table;
 
-import lombok.Value;
+import lombok.Getter;
 import org.openrewrite.Column;
 import org.openrewrite.DataTable;
 import org.openrewrite.Recipe;
+import org.openrewrite.internal.lang.Nullable;
 
 import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
 
 public class EmbeddingPerformance extends DataTable {
 
@@ -15,18 +18,64 @@ public EmbeddingPerformance(Recipe recipe) {
               "Latency characteristics of uses of embedding models.");
     }
 
-    @Value
-    public static class Row {
-        @Column(displayName = "Source file",
-                description = "The source file that the method call occurred in.")
-        String sourceFile;
+    /**
+     * Fixed-width latency histogram: 100 buckets of 0.01 second each, covering
+     * durations from zero up to (but excluding) one second.
+     */
+    public static class Histogram {
+        private static final int BUCKETS = 100;
+        private static final long MAX_NANOS = 1_000_000_000L;
+        private static final long BUCKET_WIDTH_NANOS = MAX_NANOS / BUCKETS;
 
-        @Column(displayName = "Number of requests",
-                description = "The count of requests made to the model.")
-        int count;
+        // Per-bucket counts; stays null until the first in-range duration is added.
+        @Getter
+        @Nullable
+        List<Integer> buckets;
 
-        @Column(displayName = "Max latency",
-                description = "The maximum embedding latency.")
-        Duration max;
+        public Histogram() {
+        }
+
+        /**
+         * Counts the duration in its bucket. Negative durations and durations of one
+         * second or more fall outside the histogram and are ignored.
+         */
+        public void add(Duration duration) {
+            long nanos = duration.toNanos();
+            // Range-check the long value before narrowing: a negative index would make
+            // List.set throw, and a huge quotient could wrap into a valid bucket index.
+            if (nanos < 0 || nanos >= MAX_NANOS) {
+                return;
+            }
+            if (buckets == null) {
+                buckets = new ArrayList<>(BUCKETS);
+                for (int i = 0; i < BUCKETS; i++) {
+                    buckets.add(0);
+                }
+            }
+            int bucket = (int) (nanos / BUCKET_WIDTH_NANOS);
+            buckets.set(bucket, buckets.get(bucket) + 1);
+        }
+    }
+
+    /** One row per source file that made at least one embedding request. */
+    public record Row(
+            @Column(displayName = "Source file",
+                    description = "The source file that the method call occurred in.")
+            String sourceFile,
+
+            @Column(displayName = "Number of requests",
+                    description = "The count of requests made to the model.")
+            int count,
+
+            @Column(displayName = "Histogram",
+                    description = "The latency histogram of the requests made to the model (counts). " +
+                                  "The histogram is a non-cumulative fixed distribution of 100 buckets " +
+                                  "of 0.01 second each.")
+            @Nullable
+            List<Integer> histogram,
+
+            @Column(displayName = "Max latency",
+                    description = "The maximum embedding latency.")
+            Duration max) {
     }
 }
diff --git a/src/test/java/io/moderne/ai/table/EmbeddingPerformanceTest.java b/src/test/java/io/moderne/ai/table/EmbeddingPerformanceTest.java
new file mode 100644
index 0000000..8b8d6dc
--- /dev/null
+++ b/src/test/java/io/moderne/ai/table/EmbeddingPerformanceTest.java
@@ -0,0 +1,37 @@
+package io.moderne.ai.table;
+
+import org.junit.jupiter.api.Test;
+
+import java.time.Duration;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+public class EmbeddingPerformanceTest {
+
+    @Test
+    void histogram() {
+        EmbeddingPerformance.Histogram histogram = new EmbeddingPerformance.Histogram();
+        histogram.add(Duration.ofMillis(100));
+        histogram.add(Duration.ofMillis(200));
+        histogram.add(Duration.ofSeconds(2));
+
+        List<Integer> buckets = histogram.getBuckets();
+        assertNotNull(buckets);
+        assertEquals(100, buckets.size());
+        assertEquals(1, buckets.get(10).intValue());
+        assertEquals(1, buckets.get(20).intValue());
+        // The two-second sample is beyond the one-second range and must not be counted.
+        assertEquals(2, buckets.stream().mapToInt(Integer::intValue).sum());
+    }
+
+    @Test
+    void ignoresOutOfRangeDurations() {
+        EmbeddingPerformance.Histogram histogram = new EmbeddingPerformance.Histogram();
+        histogram.add(Duration.ofMillis(-20));
+        histogram.add(Duration.ofSeconds(1));
+        assertNull(histogram.getBuckets());
+    }
+}