From b646d10c115984ce09d844895d56b37844210044 Mon Sep 17 00:00:00 2001 From: ACreed Date: Sun, 21 Jan 2024 20:25:15 +0530 Subject: [PATCH] Fixed Haskell issues. (#32) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Moved the Haskell-specific code to its respective tree-sitter file - Fixed signatures and functions being emitted as separate nodes by merging them - Fixed pattern-match-based function clauses being emitted as separate nodes by merging them. --------- Co-authored-by: aravind.mallapureddy Co-authored-by: Fynn Flügge --- doc_comments_ai/app.py | 2 +- doc_comments_ai/llm.py | 25 +++++++-- doc_comments_ai/treesitter/treesitter.py | 14 +---- doc_comments_ai/treesitter/treesitter_hs.py | 61 ++++++++++++++++++++- doc_comments_ai/treesitter/treesitter_py.py | 2 +- tests/fixtures/code_fixture_c.py | 8 +-- tests/fixtures/code_fixture_hs.py | 4 ++ tests/fixtures/code_fixture_py.py | 2 +- tests/treesitter_query_test.py | 20 +++++-- 9 files changed, 107 insertions(+), 31 deletions(-) diff --git a/doc_comments_ai/app.py b/doc_comments_ai/app.py index 28025d0..397910d 100644 --- a/doc_comments_ai/app.py +++ b/doc_comments_ai/app.py @@ -102,7 +102,7 @@ def run(): if not input().lower() == "y": continue - method_source_code = node.node.text.decode() + method_source_code = node.method_source_code tokens = utils.count_tokens(method_source_code) if tokens > 2048 and not (args.gpt4 or args.gpt3_5_16k): diff --git a/doc_comments_ai/llm.py b/doc_comments_ai/llm.py index 1744954..97cf069 100644 --- a/doc_comments_ai/llm.py +++ b/doc_comments_ai/llm.py @@ -22,8 +22,8 @@ class LLM: def __init__( self, model: GptModel = GptModel.GPT_35, - local_model: str | None = None, - azure_deployment: str | None = None, + local_model: "str | None" = None, + azure_deployment: "str | None" = None, ): max_tokens = 2048 if model == GptModel.GPT_35 else 4096 if local_model is not None: @@ -50,11 +50,16 @@ def __init__( "The doc comment should describe what the method does. 
" "{inline_comments} " "Return the method implementaion with the doc comment as a markdown code block. " - "Don't include any explanations in your response." + "Don't include any explanations {haskell_missing_signature}in your response." ) self.prompt = PromptTemplate( template=self.template, - input_variables=["language", "code", "inline_comments"], + input_variables=[ + "language", + "code", + "inline_comments", + "haskell_missing_signature", + ], ) self.chain = LLMChain(llm=self.llm, prompt=self.prompt) @@ -70,7 +75,17 @@ def generate_doc_comment(self, language, code, inline=False): else: inline_comments = "" - input = {"language": language, "code": code, "inline_comments": inline_comments} + if language == "haskell": + haskell_missing_signature = "and missing type signatures " + else: + haskell_missing_signature = "" + + input = { + "language": language, + "code": code, + "inline_comments": inline_comments, + "haskell_missing_signature": haskell_missing_signature, + } documented_code = self.chain.run(input) diff --git a/doc_comments_ai/treesitter/treesitter.py b/doc_comments_ai/treesitter/treesitter.py index b766ce4..b0535a4 100644 --- a/doc_comments_ai/treesitter/treesitter.py +++ b/doc_comments_ai/treesitter/treesitter.py @@ -12,11 +12,12 @@ def __init__( self, name: "str | bytes | None", doc_comment: "str | None", + method_source_code: "str | None", node: tree_sitter.Node, ): self.name = name self.doc_comment = doc_comment - self.method_source_code = node.text.decode() + self.method_source_code = method_source_code or node.text.decode() self.node = node @@ -46,7 +47,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]: method_name = self._query_method_name(method["method"]) doc_comment = method["doc_comment"] result.append( - TreesitterMethodNode(method_name, doc_comment, method["method"]) + TreesitterMethodNode(method_name, doc_comment, None, method["method"]) ) return result @@ -62,15 +63,6 @@ def _query_all_methods( and 
node.prev_named_sibling.type == self.doc_comment_identifier ): doc_comment_node = node.prev_named_sibling.text.decode() - else: - # added for haskell purpose. - if node.prev_named_sibling.type == "signature": - prev_node = node.prev_named_sibling - if ( - prev_node.prev_named_sibling - and prev_node.prev_named_sibling.type == self.doc_comment_identifier - ): - doc_comment_node = prev_node.prev_named_sibling.text.decode() methods.append({"method": node, "doc_comment": doc_comment_node}) else: for child in node.children: diff --git a/doc_comments_ai/treesitter/treesitter_hs.py b/doc_comments_ai/treesitter/treesitter_hs.py index 1ee345b..a9644b7 100644 --- a/doc_comments_ai/treesitter/treesitter_hs.py +++ b/doc_comments_ai/treesitter/treesitter_hs.py @@ -1,7 +1,10 @@ import tree_sitter +from typing import List, Dict + from doc_comments_ai.constants import Language -from doc_comments_ai.treesitter.treesitter import Treesitter +from doc_comments_ai.treesitter.treesitter import (Treesitter, + TreesitterMethodNode) from doc_comments_ai.treesitter.treesitter_registry import TreesitterRegistry @@ -11,8 +14,62 @@ def __init__(self): Language.HASKELL, "function", "variable", "comment" ) - def _query_method_name(self, node: tree_sitter.Node): + def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]: + self.tree = self.parser.parse(file_bytes) + result = [] + methods = self._query_all_methods(self.tree.root_node) + for method in methods: + method_name = self._query_method_name(method["method"]) + doc_comment = method["doc_comment"] + source_code = None + if method["method"].type == "signature": + sc = map(lambda x : "\n" + x.text.decode() if x.type == "function" else "", method["method"].children) + source_code = method["method"].text.decode() + "".join(sc) + result.append( + TreesitterMethodNode(method_name, doc_comment, source_code, method["method"]) + ) + return result + + def _query_all_methods( + self, + node: tree_sitter.Node, + ): + methods: 
List[Dict[tree_sitter.Node, tree_sitter.Node]] = [] if node.type == self.method_declaration_identifier: + doc_comment_node = None + if ( + node.prev_named_sibling + and node.prev_named_sibling.type == self.doc_comment_identifier + ): + doc_comment_node = node.prev_named_sibling.text.decode() + else: + if node.prev_named_sibling.type == "signature": + prev_node = node.prev_named_sibling + if ( + prev_node.prev_named_sibling + and prev_node.prev_named_sibling.type == self.doc_comment_identifier + ): + doc_comment_node = prev_node.prev_named_sibling.text.decode() + prev_node.children.append(node) + node = prev_node + methods.append({"method": node, "doc_comment": doc_comment_node}) + else: + for child in node.children: + current = self._query_all_methods(child) + if methods and current: + previous = methods[-1] + if self._query_method_name(previous["method"]) == self._query_method_name(current[0]["method"]): + previous["method"].children.extend(map(lambda x: x["method"], current)) + methods = methods[:-1] + methods.append(previous) + else: + methods.extend(current) + else: + methods.extend(current) + return methods + + def _query_method_name(self, node: tree_sitter.Node): + if node.type == "signature" or node.type == self.method_declaration_identifier: for child in node.children: if child.type == self.method_name_identifier: return child.text.decode() diff --git a/doc_comments_ai/treesitter/treesitter_py.py b/doc_comments_ai/treesitter/treesitter_py.py index 270f0c3..84b0637 100644 --- a/doc_comments_ai/treesitter/treesitter_py.py +++ b/doc_comments_ai/treesitter/treesitter_py.py @@ -19,7 +19,7 @@ def parse(self, file_bytes: bytes) -> list[TreesitterMethodNode]: for method in methods: method_name = self._query_method_name(method) doc_comment = self._query_doc_comment(method) - result.append(TreesitterMethodNode(method_name, doc_comment, method)) + result.append(TreesitterMethodNode(method_name, doc_comment, None, method)) return result def _query_method_name(self, 
node: tree_sitter.Node): diff --git a/tests/fixtures/code_fixture_c.py b/tests/fixtures/code_fixture_c.py index f80a8c6..562d785 100644 --- a/tests/fixtures/code_fixture_c.py +++ b/tests/fixtures/code_fixture_c.py @@ -24,7 +24,7 @@ def c_code_fixture(): const char *extension; enum Language language; }; - + struct LanguageMapping languageMapping[] = { {".py", PYTHON}, {".js", JAVASCRIPT}, @@ -33,16 +33,16 @@ def c_code_fixture(): {".kt", KOTLIN}, {".lua", LUA}, }; - + int numMappings = sizeof(languageMapping) / sizeof(languageMapping[0]); - + // Iterate through the mappings and check if the file extension matches. for (int i = 0; i < numMappings; i++) { if (strcmp(fileExtension, languageMapping[i].extension) == 0) { return languageMapping[i].language; } } - + return UNKNOWN; } diff --git a/tests/fixtures/code_fixture_hs.py b/tests/fixtures/code_fixture_hs.py index b35835d..a5c6411 100644 --- a/tests/fixtures/code_fixture_hs.py +++ b/tests/fixtures/code_fixture_hs.py @@ -33,4 +33,8 @@ def haskell_code_fixture(): getFileExtension fileName = let dot = dropWhile ((\=) '.') fileName in dot + +fromText :: Text -> Maybe Text +fromText "a" = Nothing +fromText x = Just x """ diff --git a/tests/fixtures/code_fixture_py.py b/tests/fixtures/code_fixture_py.py index 9aa8bae..8f4aafa 100644 --- a/tests/fixtures/code_fixture_py.py +++ b/tests/fixtures/code_fixture_py.py @@ -32,5 +32,5 @@ def get_programming_language(file_extension: str) -> Language: def get_file_extension(file_name: str) -> str: return os.path.splitext(file_name)[-1] - + """ diff --git a/tests/treesitter_query_test.py b/tests/treesitter_query_test.py index 658afed..3685873 100644 --- a/tests/treesitter_query_test.py +++ b/tests/treesitter_query_test.py @@ -359,7 +359,7 @@ def test_c_query(c_code_fixture): const char *extension; enum Language language; }; - + struct LanguageMapping languageMapping[] = { {".py", PYTHON}, {".js", JAVASCRIPT}, @@ -368,16 +368,16 @@ def test_c_query(c_code_fixture): {".kt", KOTLIN}, 
{".lua", LUA}, }; - + int numMappings = sizeof(languageMapping) / sizeof(languageMapping[0]); - + // Iterate through the mappings and check if the file extension matches. for (int i = 0; i < numMappings; i++) { if (strcmp(fileExtension, languageMapping[i].extension) == 0) { return languageMapping[i].language; } } - + return UNKNOWN; }""" ) @@ -477,7 +477,7 @@ def test_hs_query(haskell_code_fixture): haskell_code_fixture.encode() ) - assert treesitterNodes.__len__() == 2 + assert treesitterNodes.__len__() == 3 assert treesitterNodes[0].name == "getProgrammingLanguage" @@ -497,7 +497,8 @@ def test_hs_query(haskell_code_fixture): assert ( treesitterNodes[0].method_source_code - == """getProgrammingLanguage fileExtension = + == """getProgrammingLanguage :: String -> Language +getProgrammingLanguage fileExtension = let languageMapping = HM.insert ".py" PYTHON $ HM.insert ".js" JAVASCRIPT $ HM.insert ".ts" TYPESCRIPT @@ -508,3 +509,10 @@ def test_hs_query(haskell_code_fixture): Just v -> v Nothing -> UNKNOWN""" ) + + assert ( + treesitterNodes[2].method_source_code + == """fromText :: Text -> Maybe Text +fromText "a" = Nothing +fromText x = Just x""" + )