Skip to content

Commit

Permalink
Merge pull request #86 from openrewrite/optimize-speed-search
Browse files Browse the repository at this point in the history
Optimize speed search
  • Loading branch information
justine-gehring authored Jul 29, 2024
2 parents 4d00f40 + 69b1e7d commit 0ceb224
Showing 1 changed file with 35 additions and 28 deletions.
63 changes: 35 additions & 28 deletions src/main/java/io/moderne/ai/research/FindCodeThatResembles.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public class FindCodeThatResembles extends ScanningRecipe<FindCodeThatResembles.
transient CodeSearch codeSearchTable = new CodeSearch(this);
transient TopKMethodMatcher topKTable = new TopKMethodMatcher(this);
transient EmbeddingPerformance performance = new EmbeddingPerformance(this);
static Boolean populatedTopKDataTable = false;


@Override
Expand All @@ -84,9 +83,13 @@ private static class MethodSignatureWithDistance {
@Value
@RequiredArgsConstructor
public static class Accumulator {
int k;
@NonFinal
@Nullable
Boolean populatedTopKDataTable = false;
final int k;
PriorityQueue<MethodSignatureWithDistance> methodSignaturesQueue = new PriorityQueue<>(Comparator.comparingDouble(MethodSignatureWithDistance::getDistance));
EmbeddingModelClient embeddingModelClient = EmbeddingModelClient.getInstance();
private HashSet<String> methodPatternsSet = new HashSet<>();

@NonFinal
@Nullable
Expand All @@ -97,16 +100,18 @@ public static class Accumulator {
List<MethodSignatureWithDistance> topMethodSignatureWithDistances;

public void add(String methodSignature, String methodPattern, String resembles) {
for (MethodSignatureWithDistance entry : methodSignaturesQueue) {
if (entry.methodPattern.equals(methodPattern)) {
return;
}
if (methodPatternsSet.contains(methodPattern)) {
return;
}

MethodSignatureWithDistance methodSignatureWithDistance = new MethodSignatureWithDistance(methodSignature,
MethodSignatureWithDistance methodSignatureWithDistance = new MethodSignatureWithDistance(
methodSignature,
methodPattern,
(float) embeddingModelClient.getDistance(resembles, methodSignature));
(float) embeddingModelClient.getDistance(resembles, methodSignature)
);

methodSignaturesQueue.add(methodSignatureWithDistance);
methodPatternsSet.add(methodPattern);
}

public List<MethodSignatureWithDistance> getTopMethodSignatureWithDistances() {
Expand Down Expand Up @@ -135,6 +140,10 @@ public List<MethodMatcher> populateTopK() {
}
return topMethodPatterns;
}

public void setPopulatedTopKDataTable(boolean b) {
this.populatedTopKDataTable = b;
}
}

@Override
Expand All @@ -155,21 +164,28 @@ private String extractTypeName(String fullyQualifiedTypeName) {
@Override
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
cu.getTypesInUse().getUsedMethods().forEach(type -> {
String methodSignature = extractTypeName(Optional.ofNullable(type.getReturnType())
.map(Object::toString).orElse("")) + " " + type.getName();
StringBuilder methodSignatureBuilder = new StringBuilder();
StringBuilder methodPatternBuilder = new StringBuilder();

String methodSignature = methodSignatureBuilder.append(extractTypeName(Optional.ofNullable(type.getReturnType())
.map(Object::toString).orElse(""))).append(" ").append(type.getName()).toString();

methodSignatureBuilder.setLength(0); // Clear the builder for reuse

String[] parameters = new String[type.getParameterTypes().size()];
for (int i = 0; i < type.getParameterTypes().size(); i++) {
String typeName = extractTypeName(type.getParameterTypes().get(i).toString());
String paramName = type.getParameterNames().get(i);
parameters[i] = typeName + " " + paramName;
methodSignatureBuilder.append(typeName).append(" ").append(paramName);
if (i < type.getParameterTypes().size() - 1) {
methodSignatureBuilder.append(", ");
}
}

methodSignature += "(" + String.join(", ", parameters) + ")";
methodSignature += "(" + methodSignatureBuilder.toString() + ")";

String methodPattern =
Optional.ofNullable(type.getDeclaringType()).map(Object::toString)
.orElse("") + " " + type.getName() + "(..)";
methodPatternBuilder.setLength(0); // Clear the builder for reuse
String methodPattern = methodPatternBuilder.append(Optional.ofNullable(type.getDeclaringType())
.map(Object::toString).orElse("")).append(" ").append(type.getName()).append("(..)").toString();

acc.add(methodSignature, methodPattern, resembles);
});
Expand Down Expand Up @@ -219,7 +235,7 @@ public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionCon
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {

if (!populatedTopKDataTable) {
if (!acc.populatedTopKDataTable) {
List<MethodSignatureWithDistance> methodMatchersDistance = acc.getTopMethodSignatureWithDistances();
for (MethodSignatureWithDistance methodSignatureWithDistance : methodMatchersDistance) {
topKTable.insertRow(ctx, new TopKMethodMatcher.Row(
Expand All @@ -229,23 +245,14 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
resembles
));
}

populatedTopKDataTable = true;
}

boolean matches = false;
for (MethodMatcher methodMatcher : methodMatchers) {
if (methodMatcher.matches(method)) {
matches = true;
break;
}
acc.setPopulatedTopKDataTable(true);
}

boolean matches = methodMatchers.stream().anyMatch(matcher -> matcher.matches(method));
if (!matches) {
return super.visitMethodInvocation(method, ctx);
}


RelatedModelClient.Relatedness related = RelatedModelClient.getInstance()
.getRelatedness(resembles, method.printTrimmed(getCursor()));
for (Duration timing : related.getEmbeddingTimings()) {
Expand Down

0 comments on commit 0ceb224

Please sign in to comment.