Skip to content

Commit

Permalink
Apply formatter and more minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
timtebeek committed May 21, 2024
1 parent 4fca02e commit 0b56d73
Show file tree
Hide file tree
Showing 19 changed files with 163 additions and 165 deletions.
33 changes: 19 additions & 14 deletions src/main/java/io/moderne/ai/AgentGenerativeModelClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
package io.moderne.ai;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
Expand Down Expand Up @@ -76,7 +77,7 @@ public static synchronized AgentGenerativeModelClient getInstance() {
if (INSTANCE == null) {
//Check if llama.cpp is already built
File f = new File(pathToLLama + "/main");
if(!(f.exists() && !f.isDirectory()) ) {
if (!(f.exists() && !f.isDirectory())) {
//Build llama.cpp
StringWriter sw = new StringWriter();
PrintWriter procOut = new PrintWriter(sw);
Expand Down Expand Up @@ -106,12 +107,13 @@ public static synchronized AgentGenerativeModelClient getInstance() {
try {
Runtime runtime = Runtime.getRuntime();
Process proc_server = runtime.exec((new String[]
{"/bin/sh", "-c", pathToLLama + "/server -m " + pathToModel + " --port " + port }));
{"/bin/sh", "-c", pathToLLama + "/server -m " + pathToModel + " --port " + port}));

EXECUTOR_SERVICE.submit(() -> {new BufferedReader(new InputStreamReader(proc_server.getInputStream())).lines()
.forEach(procOut::println);
new BufferedReader(new InputStreamReader(proc_server.getErrorStream())).lines()
.forEach(procOut::println);
EXECUTOR_SERVICE.submit(() -> {
new BufferedReader(new InputStreamReader(proc_server.getInputStream())).lines()
.forEach(procOut::println);
new BufferedReader(new InputStreamReader(proc_server.getErrorStream())).lines()
.forEach(procOut::println);
});

if (!INSTANCE.checkForUp()) {
Expand All @@ -130,12 +132,13 @@ public static synchronized AgentGenerativeModelClient getInstance() {

private int checkForUpRequest() {
try {
HttpResponse<String> response = Unirest.head("http://127.0.0.1:"+port).asString();
HttpResponse<String> response = Unirest.head("http://127.0.0.1:" + port).asString();
return response.getStatus();
} catch (UnirestException e) {
return 523;
}
}

private boolean checkForUp() {
for (int i = 0; i < 60; i++) {
try {
Expand All @@ -152,7 +155,7 @@ private boolean checkForUp() {

public static void populateMethodsToSample(String pathToCenters) {
HashMap<String, String> tempMethodsToSample = new HashMap<>();
try (BufferedReader bufferedReader = new BufferedReader(new FileReader(pathToCenters))){
try (BufferedReader bufferedReader = new BufferedReader(new FileReader(pathToCenters))) {
String line;
String source;
String methodCall;
Expand Down Expand Up @@ -183,11 +186,11 @@ public ArrayList<String> getRecommendations(String code) {
while ((line = bufferedReader.readLine()) != null) {
promptContent.append(line).append("\n");
}
String text = "[INST]" + promptContent + code + "```\n[/INST]1." ;
String text = "[INST]" + promptContent + code + "```\n[/INST]1.";
HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(60));
HttpSender.Response raw;

HashMap <String, Object> input = new HashMap<>();
HashMap<String, Object> input = new HashMap<>();
input.put("stream", false);
input.put("prompt", text);
input.put("temperature", 0.5);
Expand All @@ -196,7 +199,7 @@ public ArrayList<String> getRecommendations(String code) {
try {
raw = http
.post("http://127.0.0.1:" + port + "/completion")
.withContent("application/json" ,
.withContent("application/json",
mapper.writeValueAsBytes(input)).send();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -254,7 +257,7 @@ public boolean isRelated(String query, String code, double threshold) {
HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(60));
HttpSender.Response raw;

HashMap <String, Object> input = new HashMap<>();
HashMap<String, Object> input = new HashMap<>();
input.put("stream", false);
input.put("prompt", promptContent);
input.put("temperature", 0.0);
Expand All @@ -265,7 +268,7 @@ public boolean isRelated(String query, String code, double threshold) {
try {
raw = http
.post("http://127.0.0.1:" + port + "/completion")
.withContent("application/json" ,
.withContent("application/json",
mapper.writeValueAsBytes(input)).send();

} catch (JsonProcessingException e) {
Expand All @@ -290,9 +293,11 @@ private boolean parseRelated(String s) {
return (s.contains("Yes") || s.contains("yes"));

}

@Value
private static class LlamaResponse {
String content;

public String getResponse() {
return content;
}
Expand All @@ -315,7 +320,7 @@ public List<CompletionProbability> getCompletionProbabilities() {
public boolean isRelated(double threshold) {
for (CompletionProbability cp : completionProbabilities) {
if (cp.getContent().equals(" Yes")) {
return cp.getProbs().get(0).getProb()>=threshold;
return cp.getProbs().get(0).getProb() >= threshold;
}
}
return false;
Expand Down
1 change: 0 additions & 1 deletion src/main/java/io/moderne/ai/ClusteringClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,4 @@ public int[] getCenters() {
}



}
10 changes: 5 additions & 5 deletions src/main/java/io/moderne/ai/EmbeddingModelClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ protected boolean removeEldestEntry(java.util.Map.Entry<String, float[]> eldest)
}
}

public static synchronized EmbeddingModelClient getInstance() {
public static synchronized EmbeddingModelClient getInstance() {
if (INSTANCE == null) {
INSTANCE = new EmbeddingModelClient();
if (INSTANCE.checkForUpRequest() != 200) {
String cmd = String.format("/usr/bin/python3 'import gradio\ngradio.'", MODELS_DIR);
String cmd = "/usr/bin/python3 'import gradio\ngradio.'";
try {
Process proc = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", cmd});
} catch (IOException e) {
Expand Down Expand Up @@ -168,18 +168,18 @@ private static double dist(float[] v1, float[] v2) {
float diff = v1[i] - v2[i];
sumOfSquaredDifferences += diff * diff;
}
return 1-Math.sqrt(sumOfSquaredDifferences);
return 1 - Math.sqrt(sumOfSquaredDifferences);
}

public float[] getEmbedding(String text) {
public float[] getEmbedding(String text) {

HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(30));
HttpSender.Response raw = null;

try {
raw = http
.post("http://127.0.0.1:7860/run/predict")
.withContent("application/json" , mapper.writeValueAsBytes(new EmbeddingModelClient.GradioRequest(text)))
.withContent("application/json", mapper.writeValueAsBytes(new EmbeddingModelClient.GradioRequest(text)))
.send();
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/io/moderne/ai/FindCommentsLanguage.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ public Space visitSpace(Space space, Space.Location loc, ExecutionContext ctx) {
if (comment instanceof TextComment) {
JavaSourceFile javaSourceFile = getCursor().firstEnclosing(JavaSourceFile.class);
distribution.insertRow(ctx, new LanguageDistribution.Row(
javaSourceFile.getSourcePath().toString(),
((TextComment) comment).getText(),
LanguageDetectorModelClient.getInstance().getLanguage(((TextComment) comment).getText()).getLanguage()
javaSourceFile.getSourcePath().toString(),
((TextComment) comment).getText(),
LanguageDetectorModelClient.getInstance().getLanguage(((TextComment) comment).getText()).getLanguage()
)
);

Expand Down
11 changes: 6 additions & 5 deletions src/main/java/io/moderne/ai/LanguageDetectorModelClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class LanguageDetectorModelClient {
private static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool(3);
private static final Path MODELS_DIR = Paths.get(System.getProperty("user.home") + "/.moderne/models");

private ObjectMapper mapper = JsonMapper.builder()
private final ObjectMapper mapper = JsonMapper.builder()
.constructorDetector(ConstructorDetector.USE_PROPERTIES_BASED)
.build()
.registerModule(new ParameterNamesModule())
Expand All @@ -68,11 +68,11 @@ protected boolean removeEldestEntry(Map.Entry<Comment, String> eldest) {
}
}

public static synchronized LanguageDetectorModelClient getInstance() {
public static synchronized LanguageDetectorModelClient getInstance() {
if (INSTANCE == null) {
INSTANCE = new LanguageDetectorModelClient();
if (INSTANCE.checkForUpRequest() != 200) {
String cmd = String.format("/usr/bin/python3 'import gradio\ngradio.'", MODELS_DIR);
String cmd = "/usr/bin/python3 'import gradio\ngradio.'";
try {
Process proc = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", cmd});
} catch (IOException e) {
Expand Down Expand Up @@ -153,7 +153,7 @@ private Function<Comment, String> timeLanguage(List<Duration> timings) {
}


public String getLanguageGradio(String text) {
public String getLanguageGradio(String text) {


HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(30));
Expand All @@ -162,7 +162,7 @@ public String getLanguageGradio(String text) {
try {
raw = http
.post("http://127.0.0.1:7861/run/predict")
.withContent("application/json" ,
.withContent("application/json",
mapper.writeValueAsBytes(new LanguageDetectorModelClient.GradioRequest(new String[]{text})))
.send();
} catch (JsonProcessingException e) {
Expand Down Expand Up @@ -191,6 +191,7 @@ private static class GradioRequest {
@Value
private static class GradioResponse {
String[] data;

public String getLanguage() {
return data[0];
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/io/moderne/ai/RelatedModelClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ protected boolean removeEldestEntry(java.util.Map.Entry<Related, Integer> eldest
}
}

public static synchronized RelatedModelClient getInstance() {
public static synchronized RelatedModelClient getInstance() {
if (INSTANCE == null) {
INSTANCE = new RelatedModelClient();
if (INSTANCE.checkForUpRequest() != 200) {
String cmd = String.format("/usr/bin/python3 'import gradio\ngradio.'", MODELS_DIR);
String cmd = "/usr/bin/python3 'import gradio\ngradio.'";
try {
Process proc = Runtime.getRuntime().exec(new String[]{"/bin/sh", "-c", cmd});
} catch (IOException e) {
Expand Down
9 changes: 4 additions & 5 deletions src/main/java/io/moderne/ai/SpellCheckCommentsInFrench.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ public Javadoc visitDocComment(Javadoc.DocComment javadoc, ExecutionContext ctx)
dc = dc.withBody(ListUtils.map(dc.getBody(), docLine -> {
if (docLine instanceof Javadoc.Text) {
String commentText = ((Javadoc.Text) docLine).getText();
if (!commentText.trim().isEmpty() && LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage().equals("fr")) {
if (!commentText.trim().isEmpty() && "fr".equals(LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage())) {
String fixedComment = SpellCheckerClient.getInstance().getCommentGradio(commentText);
if (!fixedComment.equals(commentText)) {
docLine = ((Javadoc.Text) docLine).withText(fixedComment);
Expand All @@ -78,9 +78,8 @@ public Space visitSpace(Space space, Space.Location loc, ExecutionContext ctx) {
if (c instanceof TextComment) {
TextComment tc = (TextComment) c;
String commentText = tc.getText();
if (!commentText.isEmpty() && LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage().equals("fr")
) {
if (!commentText.isEmpty() && "fr".equals(LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage())) {
String fixedComment = SpellCheckerClient.getInstance().getCommentGradio(commentText);
if (!fixedComment.equals(commentText)) {
return tc.withText(fixedComment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
@Override
public Xml.Comment visitComment(Xml.Comment comment, ExecutionContext ctx) {
String commentText = comment.getText();
if (!commentText.isEmpty() && LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage().equals("fr")
) {
if (!commentText.isEmpty() && "fr".equals(LanguageDetectorModelClient.getInstance()
.getLanguage(commentText).getLanguage())) {
String fixedComment = SpellCheckerClient.getInstance().getCommentGradio(commentText);
if (!fixedComment.equals(commentText)) {
return comment.withText(fixedComment);
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/moderne/ai/SpellCheckerClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class SpellCheckerClient {
}
}

public static synchronized SpellCheckerClient getInstance() {
public static synchronized SpellCheckerClient getInstance() {
if (INSTANCE == null) {
INSTANCE = new SpellCheckerClient();
if (INSTANCE.checkForUpRequest() != 200) {
Expand Down Expand Up @@ -121,16 +121,15 @@ private int checkForUpRequest() {
}



public String getCommentGradio(String text) {
public String getCommentGradio(String text) {

HttpSender http = new HttpUrlConnectionSender(Duration.ofSeconds(20), Duration.ofSeconds(30));
HttpSender.Response raw = null;

try {
raw = http
.post("http://127.0.0.1:7866/run/predict")
.withContent("application/json" ,
.withContent("application/json",
mapper.writeValueAsBytes(new SpellCheckerClient.GradioRequest(new String[]{text})))
.send();
} catch (JsonProcessingException e) {
Expand Down Expand Up @@ -159,6 +158,7 @@ private static class GradioRequest {
@Value
private static class GradioResponse {
String[] data;

public String getSpellCheck() {
return data[0];
}
Expand Down
15 changes: 6 additions & 9 deletions src/main/java/io/moderne/ai/research/FindCodeThatResembles.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ public Accumulator getInitialValue(ExecutionContext ctx) {
public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {

return new JavaIsoVisitor<ExecutionContext>() {


private String extractTypeName(String fullyQualifiedTypeName) {
return fullyQualifiedTypeName.replace("<.*>", "")
.substring(fullyQualifiedTypeName.lastIndexOf('.') + 1);
Expand All @@ -133,20 +131,20 @@ private String extractTypeName(String fullyQualifiedTypeName) {
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
cu.getTypesInUse().getUsedMethods().forEach(type -> {
String methodSignature = extractTypeName(Optional.ofNullable(type.getReturnType())
.map(Object::toString).orElse("")) + " " + type.getName() ;
.map(Object::toString).orElse("")) + " " + type.getName();

String[] parameters = new String[type.getParameterTypes().size()];
for (int i = 0; i < type.getParameterTypes().size(); i++) {
String typeName = extractTypeName(type.getParameterTypes().get(i).toString());
String paramName = type.getParameterNames().get(i);
parameters[i] = typeName + " " + paramName ;
parameters[i] = typeName + " " + paramName;
}

methodSignature += "(" + String.join(", ", parameters) + ")";
methodSignature += "(" + String.join(", ", parameters) + ")";

String methodPattern =
Optional.ofNullable(type.getDeclaringType()).map(Object::toString)
.orElse("") + " " + type.getName() + "(..)";
.orElse("") + " " + type.getName() + "(..)";

acc.add(methodSignature, methodPattern, resembles);
});
Expand All @@ -170,8 +168,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {

@Override
public boolean isAcceptable(SourceFile sourceFile, ExecutionContext ctx) {
boolean acceptable = sourceFile instanceof J.CompilationUnit;
return acceptable;
return sourceFile instanceof J.CompilationUnit;
}

@Override
Expand Down Expand Up @@ -236,7 +233,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
method.printTrimmed(getCursor()),
resembles,
resultEmbeddingModels,
calledGenerativeModel ? ( result ? 1 : -1) : 0
calledGenerativeModel ? (result ? 1 : -1) : 0
));
}

Expand Down
Loading

0 comments on commit 0b56d73

Please sign in to comment.