Skip to content

Commit

Permalink
kotlin2cpg: move ast creation logic into traits (#3515)
Browse files Browse the repository at this point in the history
  • Loading branch information
ursachec authored Aug 16, 2023
1 parent fabf155 commit a406ee3
Show file tree
Hide file tree
Showing 7 changed files with 2,594 additions and 2,440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,41 @@ package io.joern.kotlin2cpg.ast

import io.joern.kotlin2cpg.Constants
import io.joern.kotlin2cpg.KtFileWithMeta
import io.joern.kotlin2cpg.ast.Nodes.{namespaceBlockNode, operatorCallNode}
import io.joern.kotlin2cpg.types.{TypeConstants, TypeInfoProvider, TypeRenderer}
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.*
import io.shiftleft.passes.IntervalKeyPool
import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, ValidationMode}
import io.joern.x2cpg.{Ast, AstCreatorBase, AstNodeBuilder, Defines, ValidationMode}
import io.joern.x2cpg.datastructures.Global
import io.joern.x2cpg.datastructures.Stack.*
import io.joern.kotlin2cpg.datastructures.Scope
import io.joern.x2cpg.utils.NodeBuilders.newMethodReturnNode
import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.{DescriptorVisibilities, DescriptorVisibility}
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.lexer.{KtToken, KtTokens}
import org.slf4j.{Logger, LoggerFactory}
import overflowdb.BatchedUpdate.DiffGraphBuilder

import scala.annotation.tailrec
import scala.collection.mutable
import scala.jdk.CollectionConverters.CollectionHasAsScala
import io.shiftleft.semanticcpg.language.*

import scala.jdk.CollectionConverters.*

case class BindingInfo(node: NewBinding, edgeMeta: Seq[(NewNode, NewNode, String)])
case class ClosureBindingDef(node: NewClosureBinding, captureEdgeTo: NewMethodRef, refEdgeTo: NewNode)

class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvider, global: Global)(implicit
withSchemaValidation: ValidationMode
) extends AstCreatorBase(fileWithMeta.filename)
with KtPsiToAst
with AstForDeclarationsCreator
with AstForPrimitivesCreator
with AstForFunctionsCreator
with AstForStatementsCreator
with AstForExpressionsCreator
with AstNodeBuilder[PsiElement, AstCreator] {

protected val closureBindingDefQueue: mutable.ArrayBuffer[ClosureBindingDef] = mutable.ArrayBuffer.empty
Expand Down Expand Up @@ -243,4 +253,180 @@ class AstCreator(fileWithMeta: KtFileWithMeta, xTypeInfoProvider: TypeInfoProvid
Seq(astForUnknown(unknownExpr, argIdxMaybe, argNameMaybe, annotations))
}
}

def astForFile(fileWithMeta: KtFileWithMeta)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val ktFile = fileWithMeta.f

val importDirectives = ktFile.getImportList.getImports.asScala
val importAsts = importDirectives.toList.map(astForImportDirective)
val namespaceBlocksForImports =
for {
node <- importAsts.flatMap(_.root.collectAll[NewImport])
name = getName(node)
} yield Ast(namespaceBlockNode(name, name, relativizedPath))

val packageName = ktFile.getPackageFqName.toString
val node =
if (packageName == Constants.root)
namespaceBlockNode(
NamespaceTraversal.globalNamespaceName,
NamespaceTraversal.globalNamespaceName,
relativizedPath
)
else {
val name = packageName.split("\\.").lastOption.getOrElse("")
namespaceBlockNode(name, packageName, relativizedPath)
}
methodAstParentStack.push(node)

val name = NamespaceTraversal.globalNamespaceName
val fullName = node.fullName
val fakeGlobalTypeDecl =
typeDeclNode(ktFile, name, fullName, relativizedPath, name, NodeTypes.NAMESPACE_BLOCK, fullName)
methodAstParentStack.push(fakeGlobalTypeDecl)

val fakeGlobalMethod =
methodNode(ktFile, name, name, fullName, None, relativizedPath, Option(NodeTypes.TYPE_DECL), Option(fullName))
methodAstParentStack.push(fakeGlobalMethod)
scope.pushNewScope(fakeGlobalMethod)

val blockNode_ = blockNode(ktFile, "<empty>", registerType(TypeConstants.any))
val methodReturn = newMethodReturnNode(TypeConstants.any, None, None, None)

val declarationsAsts = ktFile.getDeclarations.asScala.flatMap(astsForDeclaration)
val fileNode = NewFile().name(fileWithMeta.relativizedPath)
val lambdaTypeDecls =
lambdaBindingInfoQueue.flatMap(_.edgeMeta.collect { case (node: NewTypeDecl, _, _) => Ast(node) })
methodAstParentStack.pop()

val allDeclarationAsts = declarationsAsts ++ lambdaAstQueue ++ lambdaTypeDecls
val fakeTypeDeclAst =
Ast(fakeGlobalTypeDecl)
.withChild(
methodAst(fakeGlobalMethod, Seq.empty, blockAst(blockNode_, allDeclarationAsts.toList), methodReturn)
)
val namespaceBlockAst =
Ast(node).withChildren(importAsts).withChild(fakeTypeDeclAst)
Ast(fileNode).withChildren(namespaceBlockAst :: namespaceBlocksForImports)
}

def astsForDeclaration(decl: KtDeclaration)(implicit typeInfoProvider: TypeInfoProvider): Seq[Ast] = {
decl match {
case c: KtClass => astsForClassOrObject(c)
case o: KtObjectDeclaration => astsForClassOrObject(o)
case n: KtNamedFunction =>
val isExtensionFn = typeInfoProvider.isExtensionFn(n)
astsForMethod(n, isExtensionFn)
case t: KtTypeAlias => Seq(astForTypeAlias(t))
case s: KtSecondaryConstructor => Seq(astForUnknown(s, None, None))
case p: KtProperty => astsForProperty(p)
case unhandled =>
logger.error(
s"Unknown declaration type encountered with text `${unhandled.getText}` and class `${unhandled.getClass}`!"
)
Seq()
}
}

def astForUnknown(
expr: KtExpression,
argIdx: Option[Int],
argNameMaybe: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val node = unknownNode(expr, Option(expr).map(_.getText).getOrElse(Constants.codePropUndefinedValue))
Ast(withArgumentIndex(node, argIdx).argumentName(argNameMaybe))
.withChildren(annotations.map(astForAnnotationEntry))
}

protected def assignmentAstForDestructuringEntry(
entry: KtDestructuringDeclarationEntry,
componentNReceiverName: String,
componentNTypeFullName: String,
componentIdx: Integer
)(implicit typeInfoProvider: TypeInfoProvider): Ast = {
val entryTypeFullName = registerType(typeInfoProvider.typeFullName(entry, TypeConstants.any))
val assignmentLHSNode = identifierNode(entry, entry.getText, entry.getText, entryTypeFullName)
val assignmentLHSAst = astWithRefEdgeMaybe(assignmentLHSNode.name, assignmentLHSNode)

val componentNIdentifierNode =
identifierNode(entry, componentNReceiverName, componentNReceiverName, componentNTypeFullName)
.argumentIndex(0)

val fallbackSignature = s"${Defines.UnresolvedNamespace}()"
val fallbackFullName =
s"${Defines.UnresolvedNamespace}${Constants.componentNPrefix}$componentIdx:$fallbackSignature"
val (fullName, signature) =
typeInfoProvider.fullNameWithSignature(entry, (fallbackFullName, fallbackSignature))
val componentNCallCode = s"$componentNReceiverName.${Constants.componentNPrefix}$componentIdx()"
val componentNCallNode = callNode(
entry,
componentNCallCode,
s"${Constants.componentNPrefix}$componentIdx",
fullName,
DispatchTypes.DYNAMIC_DISPATCH,
Some(signature),
Some(entryTypeFullName)
)

val componentNIdentifierAst = astWithRefEdgeMaybe(componentNIdentifierNode.name, componentNIdentifierNode)
val componentNAst =
callAst(componentNCallNode, Seq(), Option(componentNIdentifierAst))

val assignmentCallNode = operatorCallNode(
Operators.assignment,
s"${entry.getText} = $componentNCallCode",
None,
line(entry),
column(entry)
)
callAst(assignmentCallNode, List(assignmentLHSAst, componentNAst))
}

protected def astDerivedFullNameWithSignature(expr: KtQualifiedExpression, argAsts: List[Ast])(implicit
typeInfoProvider: TypeInfoProvider
): (String, String) = {
val astDerivedMethodFullName = expr.getSelectorExpression match {
case expression: KtCallExpression =>
val receiverPlaceholderType = Defines.UnresolvedNamespace
val shortName = expr.getSelectorExpression.getFirstChild.getText
val args = expression.getValueArguments
s"$receiverPlaceholderType.$shortName:${typeInfoProvider.anySignature(args.asScala.toList)}"
case _: KtNameReferenceExpression =>
Operators.fieldAccess
case _ =>
// TODO: add more test cases for this scenario
""
}

val astDerivedSignature = typeInfoProvider.anySignature(argAsts)
(astDerivedMethodFullName, astDerivedSignature)
}

protected def selectorExpressionArgAsts(
expr: KtQualifiedExpression
)(implicit typeInfoProvider: TypeInfoProvider): List[Ast] = {
expr.getSelectorExpression match {
case typedExpr: KtCallExpression =>
withIndex(typedExpr.getValueArguments.asScala.toSeq) { case (arg, idx) =>
astsForExpression(arg.getArgumentExpression, Some(idx))
}.flatten.toList
case typedExpr: KtNameReferenceExpression =>
val node = fieldIdentifierNode(typedExpr, typedExpr.getText, typedExpr.getText).argumentIndex(2)
List(Ast(node))
case _ => List()
}
}

protected def modifierTypeForVisibility(visibility: DescriptorVisibility): String = {
if (visibility.toString == DescriptorVisibilities.PUBLIC.toString)
ModifierTypes.PUBLIC
else if (visibility.toString == DescriptorVisibilities.PRIVATE.toString)
ModifierTypes.PRIVATE
else if (visibility.toString == DescriptorVisibilities.PROTECTED.toString)
ModifierTypes.PROTECTED
else if (visibility.toString == DescriptorVisibilities.INTERNAL.toString)
ModifierTypes.INTERNAL
else "UNKNOWN"
}
}
Loading

0 comments on commit a406ee3

Please sign in to comment.