Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Add If and MakeStruct Simplifications #14231

Merged
merged 14 commits into from
Feb 10, 2024
33 changes: 33 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import is.hail.types.tcoerce
import is.hail.types.virtual._
import is.hail.utils._

import scala.collection.mutable

object Simplify {

/** Transform 'ir' using simplification rules until none apply. */
Expand Down Expand Up @@ -248,6 +250,8 @@ object Simplify {
else
If(IsNA(c), NA(cnsq.typ), cnsq)

case If(IsNA(a), NA(_), b) if a == b => b

case If(ApplyUnaryPrimOp(Bang, c), cnsq, altr) => If(c, altr, cnsq)

case If(c1, If(c2, cnsq2, _), altr1) if c1 == c2 => If(c1, cnsq2, altr1)
Expand Down Expand Up @@ -546,6 +550,10 @@ object Simplify {
selectFields.filter(f => !insertNames.contains(f)) ++ oldFields.map(_._1)
InsertFields(SelectFields(struct, preservedFields), newFields, Some(fields.toFastSeq))

case MakeStructOfGetField(o, newNames) =>
val select = SelectFields(o, newNames.map(_._1))
CastRename(select, select.typ.asInstanceOf[TStruct].rename(newNames.toMap))

case GetTupleElement(MakeTuple(xs), idx) => xs.find(_._1 == idx).get._2

case TableCount(MatrixColsTable(child)) if child.columnCount.isDefined =>
Expand Down Expand Up @@ -1350,4 +1358,29 @@ object Simplify {
typ.blockSize,
)
}

// Extractor for a `MakeStruct` whose every field is a `GetField` on one
// common expression, i.e.
//   MakeStruct(IndexedSeq(a -> GetField(o, x) [, b -> GetField(o, y), ...]))
// It matches only when
//   - the MakeStruct has at least one field,
//   - every field is read from the same object `o`, and
//   - no field of `o` is read more than once.
// On success it yields `o` together with the (name-in-o -> new-name) pairs,
// appended in the order the MakeStruct fields are visited.
private object MakeStructOfGetField {
  def unapply(ir: IR): Option[(IR, IndexedSeq[(String, String)])] =
    ir match {
      case MakeStruct(fields) if fields.nonEmpty =>
        val seen = mutable.HashSet.empty[String]
        val renames = new BoxedArrayBuilder[(String, String)](fields.length)

        // One entry per field: the object it reads from, or None when the
        // field is not a GetField or reads a name that was already used.
        val sources = fields.map {
          case (newName, GetField(parent, oldName)) if seen.add(oldName) =>
            renames += (oldName -> newName)
            Some(parent)
          case _ => None
        }

        // Collapse to the single shared source; a None entry or two differing
        // sources collapses the whole result to None.
        val shared = sources.tail.foldLeft(sources.head) { (acc, next) =>
          if (acc == next) acc else None
        }

        // `renames` is only read when every field matched, i.e. after exactly
        // `fields.length` appends.
        // NOTE(review): `underlying()` presumably exposes the builder's backing
        // array, which would then be exactly full here - confirm.
        shared.map(_ -> renames.underlying().toFastSeq)
      case _ =>
        None
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,41 @@ import is.hail.types.physical.stypes.{EmitType, SType, SValue}
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue}
import is.hail.types.virtual.{TStruct, Type}

final case class SSubsetStruct(parent: SBaseStruct, fieldNames: IndexedSeq[String])
extends SBaseStruct {
case class SSubsetStruct(parent: SBaseStruct, fieldNames: IndexedSeq[String]) extends SBaseStruct {

override val size: Int = fieldNames.size

val _fieldIdx: Map[String, Int] = fieldNames.zipWithIndex.toMap

val newToOldFieldMapping: Map[Int, Int] = _fieldIdx
lazy val newToOldFieldMapping: Map[Int, Int] = _fieldIdx
.map { case (f, i) => (i, parent.virtualType.asInstanceOf[TStruct].fieldIdx(f)) }

override val fieldTypes: IndexedSeq[SType] =
override lazy val fieldTypes: IndexedSeq[SType] =
Array.tabulate(size)(i => parent.fieldTypes(newToOldFieldMapping(i)))

override val fieldEmitTypes: IndexedSeq[EmitType] =
override lazy val fieldEmitTypes: IndexedSeq[EmitType] =
Array.tabulate(size)(i => parent.fieldEmitTypes(newToOldFieldMapping(i)))

override lazy val virtualType: TStruct = {
val vparent = parent.virtualType.asInstanceOf[TStruct]
TStruct(fieldNames.map(f => (f, vparent.field(f).typ)): _*)
}

override def fieldIdx(fieldName: String): Int = _fieldIdx(fieldName)
override def fieldIdx(fieldName: String): Int =
_fieldIdx(fieldName)

override def castRename(t: Type): SType = {
val renamedVType = t.asInstanceOf[TStruct]
val newNames = renamedVType.fieldNames
val subsetPrevVirtualType = virtualType
val vparent = parent.virtualType.asInstanceOf[TStruct]
val newParent = TStruct(vparent.fieldNames.map(f =>
subsetPrevVirtualType.fieldIdx.get(f) match {
case Some(idxInSelectedFields) =>
val renamed = renamedVType.fields(idxInSelectedFields)
(renamed.name, renamed.typ)
case None => (f, vparent.fieldType(f))
}
): _*)
val newType = SSubsetStruct(parent.castRename(newParent).asInstanceOf[SBaseStruct], newNames)
assert(newType.virtualType == t)
newType
new SSubsetStruct(parent, renamedVType.fieldNames) {
override lazy val newToOldFieldMapping: Map[Int, Int] =
SSubsetStruct.this.newToOldFieldMapping
override lazy val fieldTypes: IndexedSeq[SType] =
SSubsetStruct.this.fieldTypes
override lazy val fieldEmitTypes: IndexedSeq[EmitType] =
SSubsetStruct.this.fieldEmitTypes
override lazy val virtualType: TStruct =
renamedVType
}
}

override def _coerceOrCopy(
Expand Down
12 changes: 10 additions & 2 deletions hail/src/main/scala/is/hail/types/virtual/TStruct.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import is.hail.expr.ir.{Env, IRParser, IntArrayBuilder}
import is.hail.utils._

import scala.collection.JavaConverters._

import org.apache.spark.sql.Row
import org.json4s.CustomSerializer
import org.json4s.JsonAST.JString

import scala.collection.mutable

class TStructSerializer extends CustomSerializer[TStruct](format =>
(
{ case JString(s) => IRParser.parseStructType(s) },
Expand Down Expand Up @@ -52,7 +53,14 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct {

lazy val types: Array[Type] = fields.map(_.typ).toArray

lazy val fieldNames: Array[String] = fields.map(_.name).toArray
// Names of `fields`, in field order. Computed eagerly (plain `val`), and each
// name is asserted to be unique while the array is being built.
val fieldNames: Array[String] = {
  val seen = mutable.Set.empty[String]
  val names = fields.iterator.map { field =>
    val name = field.name
    // `add` returns false when the name was already present.
    val firstOccurrence = seen.add(name)
    assert(firstOccurrence, f"duplicate name '$name' found in '${_toPretty}'.")
    name
  }
  names.toArray
}

def size: Int = fields.length

Expand Down
60 changes: 58 additions & 2 deletions hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package is.hail.expr.ir

import is.hail.{ExecStrategy, HailSuite}
import is.hail.expr.ir.TestUtils.IRAggCount
import is.hail.types.virtual._
import is.hail.utils.{FastSeq, Interval}
import is.hail.variant.Locus

import is.hail.{ExecStrategy, HailSuite}
import org.apache.spark.sql.Row
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper}
import org.testng.annotations.{BeforeMethod, DataProvider, Test}
Expand Down Expand Up @@ -747,4 +746,61 @@ class SimplifySuite extends HailSuite {
def testTestSwitchSimplification(x: IR, default: IR, cases: IndexedSeq[IR], expected: Any): Unit =
assert(Simplify(ctx, Switch(x, default, cases)) == expected)

/** Rows of (predicate, consequent, alternative, expected) consumed by `testIfSimplification`. */
@DataProvider(name = "IfRules")
def ifRules: Array[Array[Any]] = {
  val lhs = Ref(genUID(), TInt32)
  val rhs = Ref(genUID(), TInt32)
  val cond = Ref(genUID(), TBoolean)

  // Packs one test case into the row shape TestNG hands to the test method.
  def rule(pred: IR, cnsq: IR, altr: IR, expected: Any): Array[Any] =
    Array(pred, cnsq, altr, expected)

  Array(
    // constant predicates select a single branch
    rule(True(), lhs, Die("Failure", lhs.typ), lhs),
    rule(False(), Die("Failure", lhs.typ), lhs, lhs),
    // If(IsNA(a), NA, a) is expected to become a
    rule(IsNA(lhs), NA(lhs.typ), lhs, lhs),
    // a negated predicate is expected to swap the branches
    rule(ApplyUnaryPrimOp(Bang, cond), lhs, rhs, If(cond, rhs, lhs)),
    // a nested If on the same predicate is expected to be flattened
    rule(cond, If(cond, lhs, rhs), rhs, If(cond, lhs, rhs)),
    rule(cond, lhs, If(cond, lhs, rhs), If(cond, lhs, rhs)),
    // identical branches: only the predicate's missingness remains
    rule(cond, lhs, lhs, If(IsNA(cond), NA(lhs.typ), lhs)),
  )
}

/** Checks that simplifying `If(pred, cnsq, altr)` yields `expected` for each "IfRules" row. */
@Test(dataProvider = "IfRules")
def testIfSimplification(pred: IR, cnsq: IR, altr: IR, expected: Any): Unit = {
  val simplified = Simplify(ctx, If(pred, cnsq, altr))
  assert(simplified == expected)
}

/** Rows of (MakeStruct fields, expected simplified IR) consumed by `testMakeStruct`. */
@DataProvider(name = "MakeStructRules")
def makeStructRules: Array[Array[Any]] = {
  val struct = ref(TStruct(
    "a" -> TInt32,
    "b" -> TInt64,
    "c" -> TFloat32,
  ))

  // Reads one field of the shared struct reference.
  def field(name: String) = GetField(struct, name)

  // Packs one test case into the row shape TestNG hands to the test method.
  def rule(fields: IndexedSeq[(String, IR)], expected: IR): Array[Any] =
    Array(fields, expected)

  Array(
    // a single renamed field
    rule(
      FastSeq("x" -> field("a")),
      CastRename(SelectFields(struct, FastSeq("a")), TStruct("x" -> TInt32)),
    ),
    // several renamed fields
    rule(
      FastSeq("x" -> field("a"), "y" -> field("b")),
      CastRename(
        SelectFields(struct, FastSeq("a", "b")),
        TStruct("x" -> TInt32, "y" -> TInt64),
      ),
    ),
    // names unchanged: a plain SelectFields is expected
    rule(
      FastSeq("a" -> field("a"), "b" -> field("b")),
      SelectFields(struct, FastSeq("a", "b")),
    ),
    // every field kept under its own name: the struct itself is expected
    rule(
      FastSeq("a" -> field("a"), "b" -> field("b"), "c" -> field("c")),
      struct,
    ),
  )
}

/** Checks that simplifying `MakeStruct(fields)` yields `expected` for each "MakeStructRules" row. */
@Test(dataProvider = "MakeStructRules")
def testMakeStruct(fields: IndexedSeq[(String, IR)], expected: IR): Unit =
  assert(Simplify(ctx, MakeStruct(fields)) == expected)

}
Loading