Skip to content

Commit

Permalink
Merge pull request #73 from openrewrite/final-touches-ai-code-search
Browse files Browse the repository at this point in the history
add datatable to debug accuracy and only check method invocations in …
  • Loading branch information
justine-gehring authored May 15, 2024
2 parents 8322681 + b651896 commit 5bff3da
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/main/java/io/moderne/ai/research/FindCodeThatResembles.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.moderne.ai.AgentGenerativeModelClient;
import io.moderne.ai.EmbeddingModelClient;
import io.moderne.ai.RelatedModelClient;
import io.moderne.ai.table.CodeSearch;
import io.moderne.ai.table.EmbeddingPerformance;
import lombok.EqualsAndHashCode;
import lombok.RequiredArgsConstructor;
Expand All @@ -29,6 +30,7 @@
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaSourceFile;
import org.openrewrite.marker.SearchResult;

import java.time.Duration;
Expand All @@ -54,6 +56,7 @@ public class FindCodeThatResembles extends ScanningRecipe<FindCodeThatResembles.
example = "1000")
int k;

transient CodeSearch codeSearchTable = new CodeSearch(this);

transient EmbeddingPerformance performance = new EmbeddingPerformance(this);

Expand Down Expand Up @@ -185,6 +188,11 @@ public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionCon

@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {

if (!getLanguage().equals("java")) {
return super.visitMethodInvocation(method, ctx);
}

boolean matches = false;
for (MethodMatcher methodMatcher : methodMatchers) {
if (methodMatcher.matches(method)) {
Expand All @@ -208,11 +216,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
}
int resultEmbeddingModels = related.isRelated();
boolean result;
boolean calledGenerativeModel = false;
if (resultEmbeddingModels == 0) {
result = AgentGenerativeModelClient.getInstance().isRelated(resembles, method.printTrimmed(getCursor()), 0.5932);
calledGenerativeModel = true;
} else {
result = resultEmbeddingModels == 1;
}

// Populate data table for debugging model's accuracy
JavaSourceFile javaSourceFile = getCursor().firstEnclosing(JavaSourceFile.class);
String source = javaSourceFile.getSourcePath().toString();
codeSearchTable.insertRow(ctx, new CodeSearch.Row(
source,
method.printTrimmed(getCursor()),
resembles,
resultEmbeddingModels,
calledGenerativeModel ? ( result ? 1 : -1) : 0
));

return result ?
SearchResult.found(method) :
super.visitMethodInvocation(method, ctx);
Expand Down
57 changes: 57 additions & 0 deletions src/main/java/io/moderne/ai/table/CodeSearch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2021 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.moderne.ai.table;

import lombok.Value;
import org.openrewrite.Column;
import org.openrewrite.DataTable;
import org.openrewrite.Recipe;


public class CodeSearch extends DataTable<CodeSearch.Row> {

public CodeSearch(Recipe recipe) {
super(recipe,
"Code Search",
"Searches for method invocations that resemble a natural language query.");
}

@Value
public static class Row {
@Column(displayName = "Source",
description = "Source")
String source;

@Column(displayName = "Method",
description = "Method invocation")
String method;

@Column(displayName = "Query",
description = "Natural language query")
String query;

@Column(displayName = "Result of first models",
description = "First two embeddings models result," +
" where -1 means negative match, 0 means unsure, and 1 means positive match.")
int resultEmbedding;

@Column(displayName = "Result of second model",
description = "Second generative model's result," +
" where -1 means negative match and 1 means positive match. " +
"If the model was never queried, then the result is 0.")
int resultGenerative;
}
}

0 comments on commit 5bff3da

Please sign in to comment.