Skip to content

Commit

Permalink
Histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
jkschneider committed Aug 13, 2023
1 parent a403107 commit 5287f24
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 22 deletions.
24 changes: 14 additions & 10 deletions src/main/java/io/moderne/ai/search/FindCodeThatResembles.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.moderne.ai.search;

import io.moderne.ai.table.EmbeddingPerformance;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.table.EmbeddingPerformance;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
import org.openrewrite.*;
Expand All @@ -18,6 +18,8 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import static java.util.Objects.requireNonNull;

@RequiredArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class FindCodeThatResembles extends Recipe {
Expand Down Expand Up @@ -72,12 +74,14 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
if (tree instanceof SourceFile) {
getCursor().putMessage("count", new AtomicInteger());
getCursor().putMessage("max", new AtomicLong());
getCursor().putMessage("histogram", new EmbeddingPerformance.Histogram());
J visit = super.visit(tree, ctx);
if (getCursor().getMessage("count", new AtomicInteger()).get() > 0) {
Duration max = Duration.ofNanos(getCursor().getMessage("max", new AtomicLong()).get());
Duration max = Duration.ofNanos(requireNonNull(getCursor().<AtomicLong>getMessage("max")).get());
performance.insertRow(ctx, new EmbeddingPerformance.Row((
(SourceFile) tree).getSourcePath().toString(),
getCursor().getMessage("count", new AtomicInteger()).get(),
requireNonNull(getCursor().<AtomicInteger>getMessage("count")).get(),
requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getMessage("histogram")).getBuckets(),
max));
}
return visit;
Expand Down Expand Up @@ -106,16 +110,16 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
EmbeddingModelClient.Relatedness related = modelClient.getRelatedness(resembles,
method.printTrimmed(getCursor()));
for (Duration timing : related.embeddingTimings()) {
getCursor().getNearestMessage("count", new AtomicInteger(0)).incrementAndGet();
AtomicLong max = getCursor().getNearestMessage("max", new AtomicLong(0));
if (max.get() < timing.toNanos()) {
requireNonNull(getCursor().<AtomicInteger>getNearestMessage("count")).incrementAndGet();
requireNonNull(getCursor().<EmbeddingPerformance.Histogram>getNearestMessage("histogram")).add(timing);
AtomicLong max = getCursor().getNearestMessage("max");
if (requireNonNull(max).get() < timing.toNanos()) {
max.set(timing.toNanos());
}
}
if (related.isRelated()) {
return SearchResult.found(method);
}
return super.visitMethodInvocation(method, ctx);
return related.isRelated() ?
SearchResult.found(method) :
super.visitMethodInvocation(method, ctx);
}
});
}
Expand Down
58 changes: 46 additions & 12 deletions src/main/java/io/moderne/ai/table/EmbeddingPerformance.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package io.moderne.ai.table;

import lombok.Value;
import lombok.Getter;
import org.openrewrite.Column;
import org.openrewrite.DataTable;
import org.openrewrite.Recipe;
import org.openrewrite.internal.lang.Nullable;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

public class EmbeddingPerformance extends DataTable<EmbeddingPerformance.Row> {

Expand All @@ -15,18 +18,49 @@ public EmbeddingPerformance(Recipe recipe) {
"Latency characteristics of uses of embedding models.");
}

@Value
public static class Row {
@Column(displayName = "Source file",
description = "The source file that the method call occurred in.")
String sourceFile;
public static class Histogram {
private static final int BUCKETS = 100;
private static final long MAX_NANOS = (int) 1e9;

@Column(displayName = "Number of requests",
description = "The count of requests made to the model.")
int count;
@Getter
@Nullable
List<Integer> buckets;

@Column(displayName = "Max latency",
description = "The maximum embedding latency.")
Duration max;
/**
 * Creates an empty histogram. Nothing is allocated here: the {@code buckets} field has no
 * initializer and is only populated by {@code add(Duration)}.
 */
public Histogram() {
}

/**
 * Records one latency sample in the histogram.
 *
 * <p>The range from zero up to {@code MAX_NANOS} is split into {@code BUCKETS} buckets of
 * {@code MAX_NANOS / BUCKETS} nanoseconds each, and the counter of the bucket containing
 * {@code duration} is incremented. Samples that fall outside of the bucketed range
 * (negative durations, or durations at or beyond the last bucket) are ignored. Bucket
 * storage is allocated on the first sample that is actually recorded, so {@code buckets}
 * remains {@code null} until then.
 *
 * @param duration the latency to record
 */
public void add(Duration duration) {
    long nanos = duration.toNanos();
    if (nanos < 0) {
        // Integer division truncates toward zero, so a small negative value would
        // otherwise be counted in bucket 0 and a large one would index below zero.
        return;
    }

    // Range-check the index as a long BEFORE narrowing it: casting first could wrap a
    // very long duration around into a valid-looking bucket number.
    long index = nanos / (MAX_NANOS / BUCKETS);
    if (index >= BUCKETS) {
        return;
    }

    if (buckets == null) {
        buckets = new ArrayList<>(BUCKETS);
        for (int i = 0; i < BUCKETS; i++) {
            buckets.add(0);
        }
    }

    int bucket = (int) index;
    buckets.set(bucket, buckets.get(bucket) + 1);
}
}

/**
 * One row of the embedding performance data table, describing the embedding requests
 * made for a single source file.
 *
 * @param sourceFile the source file that the method call occurred in
 * @param count      the count of requests made to the model
 * @param histogram  per-bucket request counts, or {@code null} when no histogram is
 *                   available; the row keeps its own unmodifiable snapshot of the list
 * @param max        the maximum embedding latency
 */
public record Row(
        @Column(displayName = "Source file",
                description = "The source file that the method call occurred in.")
        String sourceFile,

        @Column(displayName = "Number of requests",
                description = "The count of requests made to the model.")
        int count,

        @Column(displayName = "Histogram",
                description = "The latency histogram of the requests made to the model (counts). " +
                              "The histogram is a non-cumulative fixed distribution of 100 buckets " +
                              "of 0.01 second each.")
        @Nullable
        List<Integer> histogram,

        @Column(displayName = "Max latency",
                description = "The maximum embedding latency.")
        Duration max) {

    /**
     * Copies the histogram so that later changes to the caller's list (for example further
     * samples added to the histogram it came from) cannot alter a row that was already built.
     */
    public Row {
        histogram = histogram == null ? null : List.copyOf(histogram);
    }
}
}
16 changes: 16 additions & 0 deletions src/test/java/io/moderne/ai/table/EmbeddingPerformanceTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.moderne.ai.table;

import org.junit.jupiter.api.Test;

import java.time.Duration;

/**
 * Tests for the latency histogram in {@link EmbeddingPerformance}.
 */
public class EmbeddingPerformanceTest {

    /**
     * Samples inside the bucketed range are counted in the bucket matching their latency,
     * and a sample beyond the range is dropped rather than counted or rejected.
     */
    @Test
    void histogram() {
        EmbeddingPerformance.Histogram histogram = new EmbeddingPerformance.Histogram();
        histogram.add(Duration.ofMillis(100));
        histogram.add(Duration.ofMillis(200));
        // Two seconds is past the last bucket, so this sample must not be recorded.
        histogram.add(Duration.ofSeconds(2));

        var buckets = histogram.getBuckets();
        check(buckets != null, "buckets should be allocated once a sample has been recorded");
        check(buckets.size() == 100, "expected 100 buckets but got " + buckets.size());

        // Each bucket is 10ms wide: 100ms falls in bucket 10 and 200ms in bucket 20.
        check(buckets.get(10) == 1, "100ms sample should be counted in bucket 10");
        check(buckets.get(20) == 1, "200ms sample should be counted in bucket 20");

        int total = 0;
        for (int count : buckets) {
            total += count;
        }
        check(total == 2, "only the two in-range samples should be counted, got " + total);
    }

    /** Fails the running test with {@code message} when {@code condition} does not hold. */
    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new AssertionError(message);
        }
    }
}

0 comments on commit 5287f24

Please sign in to comment.