Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ruby] handle private_class_method and public_class_method access modifiers #5074

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,25 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
if isClosure || isSingletonObjectMethod then refs else createMethodRefPointer(method) :: Nil
}

protected def astForMethodAccessModifier(node: MethodAccessModifier): Seq[Ast] = {
val originalAccessModifier = currentAccessModifier
popAccessModifier()

node match {
case _: PrivateMethodModifier =>
pushAccessModifier(ModifierTypes.PRIVATE)
case _: PublicMethodModifier =>
pushAccessModifier(ModifierTypes.PUBLIC)
}

val methodAst = astsForStatement(node.method)

popAccessModifier()
pushAccessModifier(originalAccessModifier)

methodAst
}

private def transformAsClosureBody(refs: List[Ast], baseStmtBlockAst: Ast) = {
// Determine which locals are captured
val capturedLocalNodes = baseStmtBlockAst.nodes
Expand Down Expand Up @@ -440,7 +459,7 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th
parameterAsts ++ anonProcParam,
stmtBlockAst,
methodReturnNode(node, Defines.Any),
newModifierNode(ModifierTypes.VIRTUAL) :: Nil
newModifierNode(ModifierTypes.VIRTUAL) :: newModifierNode(currentAccessModifier) :: Nil
)

_methodAst :: methodTypeDeclAst :: Nil foreach (Ast.storeInDiffGraph(_, diffGraph))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case node: FieldsDeclaration => astsForFieldDeclarations(node)
case node: AccessModifier => registerAccessModifier(node)
case node: MethodDeclaration => astForMethodDeclaration(node)
case node: MethodAccessModifier => astForMethodAccessModifier(node)
case node: SingletonMethodDeclaration => astForSingletonMethodDeclaration(node)
case node: MultipleAssignment => node.assignments.map(astForExpression)
case node: BreakExpression => astForBreakExpression(node) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,10 @@ object RubyIntermediateAst {
def toSimpleIdentifier: SimpleIdentifier
}

sealed trait MethodAccessModifier extends AllowedTypeDeclarationChild {
def method: RubyExpression
}

final case class PublicModifier()(span: TextSpan) extends RubyExpression(span) with AccessModifier {
override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span)
}
Expand All @@ -487,6 +491,14 @@ object RubyIntermediateAst {
override def toSimpleIdentifier: SimpleIdentifier = SimpleIdentifier(None)(span)
}

final case class PrivateMethodModifier(method: RubyExpression)(span: TextSpan)
extends RubyExpression(span)
with MethodAccessModifier

final case class PublicMethodModifier(method: RubyExpression)(span: TextSpan)
extends RubyExpression(span)
with MethodAccessModifier

/** Represents standalone `proc { ... }` or `lambda { ... }` expressions
*/
final case class ProcOrLambdaExpr(block: Block)(span: TextSpan) extends RubyExpression(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,23 @@ class RubyJsonToNodeCreator(
MemberCall(lhs, ".", RubyOperators.regexpMatch, rhs :: Nil)(obj.toTextSpan)
}

private def visitMethodAccessModifier(obj: Obj): RubyExpression = {
val body = obj.visitArray(ParserKeys.Arguments) match {
case head :: Nil => head
case xs => xs.head
}

obj(ParserKeys.Name).str match {
case "public_class_method" =>
PublicMethodModifier(body)(obj.toTextSpan)
case "private_class_method" =>
PrivateMethodModifier(body)(obj.toTextSpan)
case modifierName =>
logger.warn(s"Unknown modifier type $modifierName")
defaultResult(Option(obj.toTextSpan))
}
}

private def visitMethodDefinition(obj: Obj): RubyExpression = {
val name = obj(ParserKeys.Name).str
val parameters = obj(ParserKeys.Arguments).asInstanceOf[ujson.Obj].visitArray(ParserKeys.Children)
Expand Down Expand Up @@ -870,6 +887,7 @@ class RubyJsonToNodeCreator(
case "include" => visitInclude(obj)
case "attr_reader" | "attr_writer" | "attr_accessor" => visitFieldDeclaration(obj)
case "private" | "public" | "protected" => visitAccessModifier(obj)
case "private_class_method" | "public_class_method" => visitMethodAccessModifier(obj)
case requireLike if ImportCallNames.contains(requireLike) && !hasReceiver => visitRequireLike(obj)
case _ if BinaryOperators.isBinaryOperatorName(callName) =>
val lhs = visit(obj(ParserKeys.Receiver))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.joern.rubysrc2cpg.querying
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.Main
import io.joern.rubysrc2cpg.testfixtures.RubyCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.NodeTypes
import io.shiftleft.codepropertygraph.generated.{ModifierTypes, NodeTypes}
import io.shiftleft.codepropertygraph.generated.nodes.{File, NamespaceBlock}
import io.shiftleft.semanticcpg.language.*

Expand Down Expand Up @@ -77,4 +77,46 @@ class ModuleTests extends RubyCode2CpgFixture {
case xs => fail(s"Expected one class decl, got [${xs.code.mkString(",")}]")
}
}

"Class Method Modifiers" should {
val cpg = code("""
|# Taken from Mastodon Repo
|module LanguagesHelper
| ISO_639_1 = {}
| ISO_639_3 = {}
| SUPPORTED_LOCALES = {}
| REGIONAL_LOCALE_NAMES = {}
|
| private_class_method def self.locale_name_for_sorting(locale)
| if (supported_locale = SUPPORTED_LOCALES[locale.to_sym])
| ASCIIFolding.new.fold(supported_locale[1]).downcase
| elsif (regional_locale = REGIONAL_LOCALE_NAMES[locale.to_sym])
| ASCIIFolding.new.fold(regional_locale).downcase
| else
| locale
| end
| end
|
| def publicMethodAfterwards
| end
|end
|""".stripMargin)
"Generate private modifier on method" in {
inside(cpg.method.name("locale_name_for_sorting")._modifierViaAstOut.l) {
case virtualModifier :: privateModifier :: Nil =>
virtualModifier.modifierType shouldBe ModifierTypes.VIRTUAL
privateModifier.modifierType shouldBe ModifierTypes.PRIVATE
case xs => fail(s"Expected two modifiers, got [${xs.modifierType.mkString(",")}]")
}
}

"Revert to original access modifier after previous method def" in {
inside(cpg.method.name("publicMethodAfterwards")._modifierViaAstOut.l) {
case virtualModifier :: publicModifier :: Nil =>
virtualModifier.modifierType shouldBe ModifierTypes.VIRTUAL
publicModifier.modifierType shouldBe ModifierTypes.PUBLIC
case xs => fail(s"Expected got [${xs.modifierType.mkString(",")}]")
}
}
}
}
Loading