Skip to content

Commit

Permalink
[ruby] Handle implicit multi-assignment returns (#4898)
Browse files Browse the repository at this point in the history
* Handles implicit returns of multi-assignments by returning an array of the LHS assignment targets.
* For implicit returns of multi-assignments created as desugaring of splatted parameters, returns `nil` as per what is evaluated in the Ruby interpreter.
  • Loading branch information
DavidBakerEffendi authored Sep 6, 2024
1 parent 03c9e10 commit daef3c5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import io.joern.rubysrc2cpg.datastructures.BlockScope
import io.joern.rubysrc2cpg.passes.Defines
import io.joern.rubysrc2cpg.passes.Defines.getBuiltInType
import io.joern.x2cpg.{Ast, ValidationMode}
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, ModifierTypes, Operators}
import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewMethod, NewMethodRef, NewTypeDecl}
import io.shiftleft.codepropertygraph.generated.nodes.NewControlStructure
import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, ModifierTypes}

trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { this: AstCreator =>

Expand Down Expand Up @@ -156,6 +156,13 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
astForReturnExpression(ReturnExpression(List(node))(node.span)) :: Nil
case node: SingleAssignment =>
astForSingleAssignment(node) :: List(astForReturnExpression(ReturnExpression(List(node.lhs))(node.span)))
case node: DefaultMultipleAssignment =>
astsForStatement(node) ++ astsForImplicitReturnStatement(ArrayLiteral(node.assignments.map(_.lhs))(node.span))
case node: GroupedParameterDesugaring =>
// If the desugaring is the last expression, then we should return nil
val nilReturnSpan = node.span.spanStart("return nil")
val nilReturnLiteral = StaticLiteral(Defines.NilClass)(nilReturnSpan)
astsForStatement(node) ++ astsForImplicitReturnStatement(nilReturnLiteral)
case node: AttributeAssignment =>
List(
astForAttributeAssignment(node),
Expand All @@ -165,7 +172,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case ret: ReturnExpression => astForReturnExpression(ret) :: Nil
case node: (MethodDeclaration | SingletonMethodDeclaration) =>
(astsForStatement(node) :+ astForReturnMethodDeclarationSymbolName(node)).toList
case _: BreakExpression => astsForStatement(node).toList
case node =>
logger.warn(
s"Implicit return here not supported yet: ${node.text} (${node.getClass.getSimpleName}), only generating statement"
Expand Down Expand Up @@ -194,10 +200,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
returnAst(returnNode(node, code(node)), List(astForMemberAccess(node)))
}

private def astForReturnMemberCall(node: MemberCall): Ast = {
returnAst(returnNode(node, code(node)), List(astForMemberCall(node)))
}

protected def astForBreakExpression(node: BreakExpression): Ast = {
val _node = NewControlStructure()
.controlStructureType(ControlStructureTypes.BREAK)
Expand Down Expand Up @@ -245,7 +247,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case StatementList(statements) => StatementList(statementListReturningLastExpression(statements))(x.span)
case clause: ControlFlowClause => clauseReturningLastExpression(clause)
case node: ControlFlowStatement => transform(node)
case node: BreakExpression => node
case node: ReturnExpression => node
case _ => ReturnExpression(x :: Nil)(x.span)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,12 @@ object RubyIntermediateAst {
extends RubyExpression(span)
with MethodParameter

final case class GroupedParameter(name: String, tmpParam: RubyExpression, multipleAssignment: RubyExpression)(
span: TextSpan
) extends RubyExpression(span)
final case class GroupedParameter(
name: String,
tmpParam: RubyExpression,
multipleAssignment: GroupedParameterDesugaring
)(span: TextSpan)
extends RubyExpression(span)
with MethodParameter

sealed trait CollectionParameter extends MethodParameter
Expand All @@ -193,9 +196,17 @@ object RubyIntermediateAst {
extends RubyExpression(span)
with RubyStatement

final case class MultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan)
trait MultipleAssignment extends RubyStatement {
def assignments: List[SingleAssignment]
}

final case class DefaultMultipleAssignment(assignments: List[SingleAssignment])(span: TextSpan)
extends RubyExpression(span)
with RubyStatement
with MultipleAssignment

final case class GroupedParameterDesugaring(assignments: List[SingleAssignment])(span: TextSpan)
extends RubyExpression(span)
with MultipleAssignment

final case class SplattingRubyNode(target: RubyExpression)(span: TextSpan) extends RubyExpression(span)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ class RubyNodeCreator(variableNameGen: FreshNameGenerator[String] = FreshNameGen
GroupedParameter(
tmpMandatoryParam.span.text,
tmpMandatoryParam,
MultipleAssignment(singleAssignments.toList)(ctx.toTextSpan)
GroupedParameterDesugaring(singleAssignments)(ctx.toTextSpan)
)(ctx.toTextSpan)
}

Expand Down Expand Up @@ -607,7 +607,7 @@ class RubyNodeCreator(variableNameGen: FreshNameGenerator[String] = FreshNameGen
} else {
defaultAssignments
}
MultipleAssignment(assignments)(ctx.toTextSpan)
DefaultMultipleAssignment(assignments)(ctx.toTextSpan)
}

override def visitMultipleLeftHandSide(ctx: RubyParser.MultipleLeftHandSideContext): RubyExpression = {
Expand All @@ -627,7 +627,7 @@ class RubyNodeCreator(variableNameGen: FreshNameGenerator[String] = FreshNameGen

override def visitPackingLeftHandSide(ctx: RubyParser.PackingLeftHandSideContext): RubyExpression = {
val splatNode = Option(ctx.leftHandSide()) match {
case Some(lhs) => SplattingRubyNode(visit(ctx.leftHandSide))(ctx.toTextSpan)
case Some(lhs) => SplattingRubyNode(visit(lhs))(ctx.toTextSpan)
case None =>
SplattingRubyNode(MandatoryParameter("_")(ctx.toTextSpan.spanStart("_")))(ctx.toTextSpan.spanStart("*_"))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,25 @@ class DestructuredAssignmentsTests extends RubyCode2CpgFixture {
case xs => fail(s"Expected 3 assignments, got ${xs.code.mkString(",")}")
}
}

"multi-assignments as a return value" should {

val cpg = code("""
|def f
| a, b = 1, 2 # => return [1, 2]
|end
|""".stripMargin)

"create an explicit return of the LHS values as an array" in {
val arrayLiteral = cpg.method.name("f").methodReturn.toReturn.astChildren.isCall.head

arrayLiteral.name shouldBe Operators.arrayInitializer
arrayLiteral.methodFullName shouldBe Operators.arrayInitializer
arrayLiteral.code shouldBe "a, b = 1, 2"

arrayLiteral.astChildren.isIdentifier.code.l shouldBe List("a", "b")
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -474,5 +474,10 @@ class DoBlockTests extends RubyCode2CpgFixture {
case xs => fail(s"Expected 4 assignments, got [${xs.code.mkString(", ")}]")
}
}

"Return nil and not the desugaring" in {
val nilLiteral = cpg.method.isLambda.methodReturn.toReturn.astChildren.isLiteral.head
nilLiteral.code shouldBe "return nil"
}
}
}

0 comments on commit daef3c5

Please sign in to comment.