Skip to content

Commit

Permalink
Fixed Haskell issues. (#32)
Browse files Browse the repository at this point in the history
- Moved the Haskell-specific code into its own tree-sitter file.
- Fixed a type signature and its function being reported as separate nodes by merging them.
- Fixed pattern-match clauses of the same function being reported as separate nodes by merging them.

---------

Co-authored-by: aravind.mallapureddy <[email protected]>
Co-authored-by: Fynn Flügge <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2024
1 parent 6a634c6 commit b646d10
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 31 deletions.
2 changes: 1 addition & 1 deletion doc_comments_ai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run():
if not input().lower() == "y":
continue

method_source_code = node.node.text.decode()
method_source_code = node.method_source_code

tokens = utils.count_tokens(method_source_code)
if tokens > 2048 and not (args.gpt4 or args.gpt3_5_16k):
Expand Down
25 changes: 20 additions & 5 deletions doc_comments_ai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class LLM:
def __init__(
self,
model: GptModel = GptModel.GPT_35,
local_model: str | None = None,
azure_deployment: str | None = None,
local_model: "str | None" = None,
azure_deployment: "str | None" = None,
):
max_tokens = 2048 if model == GptModel.GPT_35 else 4096
if local_model is not None:
Expand All @@ -50,11 +50,16 @@ def __init__(
"The doc comment should describe what the method does. "
"{inline_comments} "
"Return the method implementaion with the doc comment as a markdown code block. "
"Don't include any explanations in your response."
"Don't include any explanations {haskell_missing_signature}in your response."
)
self.prompt = PromptTemplate(
template=self.template,
input_variables=["language", "code", "inline_comments"],
input_variables=[
"language",
"code",
"inline_comments",
"haskell_missing_signature",
],
)
self.chain = LLMChain(llm=self.llm, prompt=self.prompt)

Expand All @@ -70,7 +75,17 @@ def generate_doc_comment(self, language, code, inline=False):
else:
inline_comments = ""

input = {"language": language, "code": code, "inline_comments": inline_comments}
if language == "haskell":
haskell_missing_signature = "and missing type signatures "
else:
haskell_missing_signature = ""

input = {
"language": language,
"code": code,
"inline_comments": inline_comments,
"haskell_missing_signature": haskell_missing_signature,
}

documented_code = self.chain.run(input)

Expand Down
14 changes: 3 additions & 11 deletions doc_comments_ai/treesitter/treesitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def __init__(
self,
name: "str | bytes | None",
doc_comment: "str | None",
method_source_code: "str | None",
node: tree_sitter.Node,
):
self.name = name
self.doc_comment = doc_comment
self.method_source_code = node.text.decode()
self.method_source_code = method_source_code or node.text.decode()
self.node = node


Expand Down Expand Up @@ -46,7 +47,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
method_name = self._query_method_name(method["method"])
doc_comment = method["doc_comment"]
result.append(
TreesitterMethodNode(method_name, doc_comment, method["method"])
TreesitterMethodNode(method_name, doc_comment, None, method["method"])
)
return result

Expand All @@ -62,15 +63,6 @@ def _query_all_methods(
and node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = node.prev_named_sibling.text.decode()
else:
# added for haskell purpose.
if node.prev_named_sibling.type == "signature":
prev_node = node.prev_named_sibling
if (
prev_node.prev_named_sibling
and prev_node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = prev_node.prev_named_sibling.text.decode()
methods.append({"method": node, "doc_comment": doc_comment_node})
else:
for child in node.children:
Expand Down
61 changes: 59 additions & 2 deletions doc_comments_ai/treesitter/treesitter_hs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import tree_sitter
from typing import List, Dict


from doc_comments_ai.constants import Language
from doc_comments_ai.treesitter.treesitter import Treesitter
from doc_comments_ai.treesitter.treesitter import (Treesitter,
TreesitterMethodNode)
from doc_comments_ai.treesitter.treesitter_registry import TreesitterRegistry


Expand All @@ -11,8 +14,62 @@ def __init__(self):
Language.HASKELL, "function", "variable", "comment"
)

def _query_method_name(self, node: tree_sitter.Node):
def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
    """Parse Haskell source bytes and return one node per collected method.

    The parsed tree is stored on ``self.tree``. For every entry returned by
    ``self._query_all_methods`` a ``TreesitterMethodNode`` is built from the
    method name, its doc comment, an optional merged source string and the
    underlying tree-sitter node.

    When the collected node is a ``signature``, the source string is the
    signature text followed by the text of each ``function`` child, one per
    line. Otherwise ``None`` is passed as the source string.
    """
    self.tree = self.parser.parse(file_bytes)
    collected = []
    for entry in self._query_all_methods(self.tree.root_node):
        target = entry["method"]
        name = self._query_method_name(target)
        comment = entry["doc_comment"]
        merged_source = None
        if target.type == "signature":
            header = target.text.decode()
            # Only children of type "function" contribute text; every other
            # child of the signature is ignored here.
            clauses = [
                "\n" + child.text.decode()
                for child in target.children
                if child.type == "function"
            ]
            merged_source = header + "".join(clauses)
        collected.append(
            TreesitterMethodNode(name, comment, merged_source, target)
        )
    return collected

def _query_all_methods(
self,
node: tree_sitter.Node,
):
methods: List[Dict[tree_sitter.Node, tree_sitter.Node]] = []
if node.type == self.method_declaration_identifier:
doc_comment_node = None
if (
node.prev_named_sibling
and node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = node.prev_named_sibling.text.decode()
else:
if node.prev_named_sibling.type == "signature":
prev_node = node.prev_named_sibling
if (
prev_node.prev_named_sibling
and prev_node.prev_named_sibling.type == self.doc_comment_identifier
):
doc_comment_node = prev_node.prev_named_sibling.text.decode()
prev_node.children.append(node)
node = prev_node
methods.append({"method": node, "doc_comment": doc_comment_node})
else:
for child in node.children:
current = self._query_all_methods(child)
if methods and current:
previous = methods[-1]
if self._query_method_name(previous["method"]) == self._query_method_name(current[0]["method"]):
previous["method"].children.extend(map(lambda x: x["method"], current))
methods = methods[:-1]
methods.append(previous)
else:
methods.extend(current)
else:
methods.extend(current)
return methods

def _query_method_name(self, node: tree_sitter.Node):
if node.type == "signature" or node.type == self.method_declaration_identifier:
for child in node.children:
if child.type == self.method_name_identifier:
return child.text.decode()
Expand Down
2 changes: 1 addition & 1 deletion doc_comments_ai/treesitter/treesitter_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]:
for method in methods:
method_name = self._query_method_name(method)
doc_comment = self._query_doc_comment(method)
result.append(TreesitterMethodNode(method_name, doc_comment, method))
result.append(TreesitterMethodNode(method_name, doc_comment, None, method))
return result

def _query_method_name(self, node: tree_sitter.Node):
Expand Down
8 changes: 4 additions & 4 deletions tests/fixtures/code_fixture_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def c_code_fixture():
const char *extension;
enum Language language;
};
struct LanguageMapping languageMapping[] = {
{".py", PYTHON},
{".js", JAVASCRIPT},
Expand All @@ -33,16 +33,16 @@ def c_code_fixture():
{".kt", KOTLIN},
{".lua", LUA},
};
int numMappings = sizeof(languageMapping) / sizeof(languageMapping[0]);
// Iterate through the mappings and check if the file extension matches.
for (int i = 0; i < numMappings; i++) {
if (strcmp(fileExtension, languageMapping[i].extension) == 0) {
return languageMapping[i].language;
}
}
return UNKNOWN;
}
Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/code_fixture_hs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ def haskell_code_fixture():
getFileExtension fileName =
let dot = dropWhile ((\=) '.') fileName
in dot
fromText :: Text -> Maybe Text
fromText "a" = Nothing
fromText x = Just x
"""
2 changes: 1 addition & 1 deletion tests/fixtures/code_fixture_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def get_programming_language(file_extension: str) -> Language:
def get_file_extension(file_name: str) -> str:
return os.path.splitext(file_name)[-1]
"""
20 changes: 14 additions & 6 deletions tests/treesitter_query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_c_query(c_code_fixture):
const char *extension;
enum Language language;
};
struct LanguageMapping languageMapping[] = {
{".py", PYTHON},
{".js", JAVASCRIPT},
Expand All @@ -368,16 +368,16 @@ def test_c_query(c_code_fixture):
{".kt", KOTLIN},
{".lua", LUA},
};
int numMappings = sizeof(languageMapping) / sizeof(languageMapping[0]);
// Iterate through the mappings and check if the file extension matches.
for (int i = 0; i < numMappings; i++) {
if (strcmp(fileExtension, languageMapping[i].extension) == 0) {
return languageMapping[i].language;
}
}
return UNKNOWN;
}"""
)
Expand Down Expand Up @@ -477,7 +477,7 @@ def test_hs_query(haskell_code_fixture):
haskell_code_fixture.encode()
)

assert treesitterNodes.__len__() == 2
assert treesitterNodes.__len__() == 3

assert treesitterNodes[0].name == "getProgrammingLanguage"

Expand All @@ -497,7 +497,8 @@ def test_hs_query(haskell_code_fixture):

assert (
treesitterNodes[0].method_source_code
== """getProgrammingLanguage fileExtension =
== """getProgrammingLanguage :: String -> Language
getProgrammingLanguage fileExtension =
let languageMapping = HM.insert ".py" PYTHON
$ HM.insert ".js" JAVASCRIPT
$ HM.insert ".ts" TYPESCRIPT
Expand All @@ -508,3 +509,10 @@ def test_hs_query(haskell_code_fixture):
Just v -> v
Nothing -> UNKNOWN"""
)

assert (
treesitterNodes[2].method_source_code
== """fromText :: Text -> Maybe Text
fromText "a" = Nothing
fromText x = Just x"""
)

0 comments on commit b646d10

Please sign in to comment.