Skip to content

Commit

Permalink
[javasrc] keepTypeArguments Flag
Browse files Browse the repository at this point in the history
The desire for type arguments to be persisted in type nodes and properties has been around for a while.

This PR adds a hidden, off-by-default flag, `keep-type-arguments`, signals to the `TypeInfoCalculator` to build the full name with the full names of the type arguments.

Resolves #4488
  • Loading branch information
DavidBakerEffendi committed Apr 29, 2024
1 parent deec496 commit 151c6df
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ case class JavaSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends C
}

override def applyPostProcessingPasses(cpg: Cpg): Cpg = {
if (javaConfig.forall(_.enableTypeRecovery))
if (javaConfig.exists(_.enableTypeRecovery))
JavaSrc2Cpg.typeRecoveryPasses(cpg, javaConfig).foreach(_.createAndApply())
super.applyPostProcessingPasses(cpg)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ final case class Config(
showEnv: Boolean = false,
skipTypeInfPass: Boolean = false,
dumpJavaparserAsts: Boolean = false,
cacheJdkTypeSolver: Boolean = false
cacheJdkTypeSolver: Boolean = false,
keepTypeArguments: Boolean = true
) extends X2CpgConfig[Config]
with TypeRecoveryParserConfig[Config] {
def withInferenceJarPaths(paths: Set[String]): Config = {
Expand Down Expand Up @@ -67,6 +68,10 @@ final case class Config(
def withCacheJdkTypeSolver(value: Boolean): Config = {
copy(cacheJdkTypeSolver = value).withInheritedFields(this)
}

def withKeepTypeArguments(value: Boolean): Config = {
copy(keepTypeArguments = value).withInheritedFields(this)
}
}

private object Frontend {
Expand Down Expand Up @@ -120,7 +125,11 @@ private object Frontend {
opt[Unit]("cache-jdk-type-solver")
.hidden()
.action((_, c) => c.withCacheJdkTypeSolver(true))
.text("Re-use JDK type solver between scans.")
.text("Re-use JDK type solver between scans."),
opt[Unit]("keep-type-arguments")
.hidden()
.action((_, c) => c.withKeepTypeArguments(true))
.text("Type full names of variables keep their type arguments.")
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class AstCreator(
javaParserAst: CompilationUnit,
fileContent: Option[String],
global: Global,
val symbolSolver: JavaSymbolSolver
val symbolSolver: JavaSymbolSolver,
keepTypeArguments: Boolean
)(implicit val withSchemaValidation: ValidationMode)
extends AstCreatorBase(filename)
with AstNodeBuilder[Node, AstCreator]
Expand All @@ -96,8 +97,9 @@ class AstCreator(

private[astcreation] val scope = Scope()

private[astcreation] val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver)
private[astcreation] val bindingTableCache = mutable.HashMap.empty[String, BindingTable]
private[astcreation] val typeInfoCalc: TypeInfoCalculator =
TypeInfoCalculator(global, symbolSolver, keepTypeArguments)
private[astcreation] val bindingTableCache = mutable.HashMap.empty[String, BindingTable]

/** Entry point of AST creation. Translates a compilation unit created by JavaParser into a DiffGraph containing the
* corresponding CPG AST.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[Str
symbolSolver.inject(compilationUnit)
val contentToUse = if (!config.disableFileContent) fileContent else None
diffGraph.absorb(
new AstCreator(filename, compilationUnit, contentToUse, global, symbolSolver)(config.schemaValidation)
new AstCreator(filename, compilationUnit, contentToUse, global, symbolSolver, config.keepTypeArguments)(
config.schemaValidation
)
.createAst()
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@ import com.github.javaparser.resolution.declarations.{
ResolvedTypeDeclaration,
ResolvedTypeParameterDeclaration
}
import com.github.javaparser.resolution.types._
import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap
import com.github.javaparser.resolution.logic.InferenceVariableType
import com.github.javaparser.resolution.model.typesystem.{LazyType, NullType}
import com.github.javaparser.resolution.types.*
import com.github.javaparser.resolution.types.parametrization.ResolvedTypeParametersMap
import io.joern.javasrc2cpg.typesolvers.TypeInfoCalculator.{TypeConstants, TypeNameConstants}
import io.joern.x2cpg.datastructures.Global
import org.slf4j.LoggerFactory

import scala.jdk.CollectionConverters._
import scala.collection.mutable
import scala.jdk.CollectionConverters.*
import scala.jdk.OptionConverters.RichOptional
import scala.util.Try

class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver) {
class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver, keepTypeArguments: Boolean) {
private val logger = LoggerFactory.getLogger(this.getClass)
private val emptyTypeParamValues = ResolvedTypeParametersMap.empty()

Expand Down Expand Up @@ -71,6 +72,12 @@ class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver) {
fullyQualified: Boolean
): Option[String] = {
typ match {
case refType: ResolvedReferenceType if keepTypeArguments =>
val typeParams = refType.getTypeParametersMap.asScala.map(_.b).map(fullName(_).getOrElse(TypeConstants.Object))
nameOrFullName(refType.getTypeDeclaration.get, fullyQualified).map {
case baseType if typeParams.isEmpty => baseType
case baseType => s"$baseType<${typeParams.mkString(",")}>"
}
case refType: ResolvedReferenceType =>
nameOrFullName(refType.getTypeDeclaration.get, fullyQualified)
case lazyType: LazyType =>
Expand Down Expand Up @@ -289,7 +296,4 @@ object TypeInfoCalculator {
"java.lang.Boolean"
)

def apply(global: Global, symbolResolver: SymbolResolver): TypeInfoCalculator = {
new TypeInfoCalculator(global, symbolResolver)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package io.joern.javasrc2cpg.querying

import io.joern.javasrc2cpg.Config
import io.joern.javasrc2cpg.testfixtures.JavaSrcCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, Local}
import io.shiftleft.semanticcpg.language._
import io.shiftleft.semanticcpg.language.*

class VarDeclTests extends JavaSrcCode2CpgFixture {

Expand Down Expand Up @@ -155,4 +156,31 @@ class VarDeclTests extends JavaSrcCode2CpgFixture {
assigX.code shouldBe "x = 1"
assigX.order shouldBe 6
}

"generics with 'keep type arguments' config" should {
val cpg = code("""
|import java.util.ArrayList;
|import java.util.List;
|import java.util.HashMap;
|
|public class Main {
| public static void main(String[] args) {
| // Create a List of Strings
| List<String> stringList = new ArrayList<>();
| var stringIntMap = new HashMap<String, Integer>();
| }
|}
|
|""".stripMargin)
.withConfig(Config().withKeepTypeArguments(true))

"show the fully qualified type arguments for `List`" in {
cpg.identifier("stringList").typeFullName.head shouldBe "java.util.List<java.lang.String>"
}

"show the fully qualified type arguments for `Map`" in {
cpg.identifier("stringIntMap").typeFullName.head shouldBe "java.util.HashMap<java.lang.String,java.lang.Integer>"
}

}
}

0 comments on commit 151c6df

Please sign in to comment.