Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Add If and MakeStruct Simplifications #14231

Merged
merged 14 commits into from
Feb 10, 2024
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
iec.map(cb)(pc => cast(cb, pc))
case CastRename(v, _typ) =>
emitI(v)
.map(cb)(pc => pc.st.castRename(_typ).fromValues(pc.valueTuple))
.map(cb)(_.castRename(_typ))
case NA(typ) =>
IEmitCode.missing(cb, SUnreachable.fromVirtualType(typ).defaultValue)
case IsNA(v) =>
Expand Down
33 changes: 33 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import is.hail.types.tcoerce
import is.hail.types.virtual._
import is.hail.utils._

import scala.collection.mutable

object Simplify {

/** Transform 'ir' using simplification rules until none apply. */
Expand Down Expand Up @@ -248,6 +250,8 @@ object Simplify {
else
If(IsNA(c), NA(cnsq.typ), cnsq)

case If(IsNA(a), NA(_), b) if a == b => b

case If(ApplyUnaryPrimOp(Bang, c), cnsq, altr) => If(c, altr, cnsq)

case If(c1, If(c2, cnsq2, _), altr1) if c1 == c2 => If(c1, cnsq2, altr1)
Expand Down Expand Up @@ -546,6 +550,10 @@ object Simplify {
selectFields.filter(f => !insertNames.contains(f)) ++ oldFields.map(_._1)
InsertFields(SelectFields(struct, preservedFields), newFields, Some(fields.toFastSeq))

case MakeStructOfGetField(o, newNames) =>
val select = SelectFields(o, newNames.map(_._1))
CastRename(select, select.typ.asInstanceOf[TStruct].rename(newNames.toMap))

case GetTupleElement(MakeTuple(xs), idx) => xs.find(_._1 == idx).get._2

case TableCount(MatrixColsTable(child)) if child.columnCount.isDefined =>
Expand Down Expand Up @@ -1350,4 +1358,29 @@ object Simplify {
typ.blockSize,
)
}

// Match on expressions of the form
// MakeStruct(IndexedSeq(a -> GetField(o, x) [, b -> GetField(o, y), ...]))
// where
// - all fields are extracted from the same object, `o`
// - all references to the fields in o are unique
private object MakeStructOfGetField {
def unapply(ir: IR): Option[(IR, IndexedSeq[(String, String)])] =
ir match {
case MakeStruct(fields) if fields.nonEmpty =>
val names = mutable.HashSet.empty[String]
val rewrites = new BoxedArrayBuilder[(String, String)](fields.length)

fields.view.map {
case (a, GetField(o, b)) if names.add(b) =>
rewrites += (b -> a)
Some(o)
case _ => None
}
.reduce((a, b) => if (a == b) a else None)
.map(_ -> rewrites.underlying().toFastSeq)
case _ =>
None
}
}
}
47 changes: 28 additions & 19 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2378,39 +2378,48 @@ object LowerTableIR {
lowered
}

// format: off

/* We have a couple of options when repartitioning a table:
* 1. Send only the contexts needed to compute each new partition and take/drop the rows that fall
* in that partition.
* 2. Compute the table with the old partitioner, write the table to cloud storage then read the
* new partitions from the index.
* 1. Send only the contexts needed to compute each new partition and
* take/drop the rows that fall in that partition.
* 2. Compute the table with the old partitioner, write the table to cloud
* storage then read the new partitions from the index.
*
* We'd like to do 1 as keeping things in memory (with perhaps a bit of work duplication) is
* generally less expensive than writing and reading a table to and from cloud storage. There
* comes a cross-over point, however, where it's cheaper to do the latter. One such example is as
* follows: consider a repartitioning where the same context is used to compute multiple
* partitions. The (parallel) computation of each partition involves at least all of the work to
* compute the previous partition:
* We'd like to do 1 as keeping things in memory (with perhaps a bit of work
* duplication) is generally less expensive than writing and reading a table
* to and from cloud storage. There comes a cross-over point, however, where
* it's cheaper to do the latter. One such example is as follows: consider a
* repartitioning where the same context is used to compute multiple
* partitions. The (parallel) computation of each partition involves at least
* all of the work to compute the previous partition:
*
* *----------------------* in: | | ...
* *----------------------* / | \ / | \
* *--* *---* *--* out: | | | | ... | |
* *--* *---* *--*
* *----------------------*
* in: | | ...
* *----------------------*
* / | \
* / | \
* *--* *---* *--*
* out: | | | | ... | |
* *--* *---* *--*
*
* We can estimate the relative cost of computing the new partitions vs spilling as being
* proportional to the mean number of old partitions used to compute new partitions. */
* We can estimate the relative cost of computing the new partitions vs
* spilling as being proportional to the mean number of old partitions
* used to compute new partitions.
*/
def isRepartitioningCheap(original: RVDPartitioner, planned: RVDPartitioner): Boolean = {
val cost =
if (original.numPartitions == 0)
0.0
else
(0.0167 / original.numPartitions) * planned
.rangeBounds
.map { intrvl =>
val (lo, hi) = original.intervalRange(intrvl); hi - lo
}
.map { intrvl => val (lo, hi) = original.intervalRange(intrvl); hi - lo }
.sum

log.info(s"repartition cost: $cost")
cost <= 1.0
}

// format: on
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class PBaseStruct extends PType {

def size: Int = fields.length

def isIsomorphicTo(other: PBaseStruct) =
def isIsomorphicTo(other: PBaseStruct): Boolean =
this.fields.size == other.fields.size && this.isCompatibleWith(other)

def _toPretty: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.asm4s.{Code, Value}
import is.hail.backend.HailStateManager
import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.SValue
import is.hail.types.physical.stypes.concrete.SSubsetStruct
import is.hail.types.physical.stypes.concrete.SStructView
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue}
import is.hail.types.virtual.TStruct
import is.hail.utils._
Expand Down Expand Up @@ -137,7 +137,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext
): Long =
throw new UnsupportedOperationException

def sType: SSubsetStruct = SSubsetStruct(ps.sType.asInstanceOf[SBaseStruct], _fieldNames)
def sType: SBaseStruct =
SStructView.subset(_fieldNames, ps.sType)

def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean)
: Value[Long] =
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.concrete.SRNGStateValue
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.virtual.Type

object SCode {
def add(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = {
Expand Down Expand Up @@ -114,6 +115,9 @@ trait SValue {
throw new UnsupportedOperationException(s"Stype $st has no hashcode")

def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value

def castRename(t: Type): SValue =
st.castRename(t).fromValues(valueTuple)
}

trait SSettable extends SValue {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue])
override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
values(fieldIdx).m

override def subset(fieldNames: String*): SStackStructValue = {
override def subset(fieldNames: String*): SBaseStructValue = {
val newToOld = fieldNames.map(st.fieldIdx).toArray
val oldVType = st.virtualType.asInstanceOf[TStruct]
val newVirtualType = TStruct(newToOld.map(i => (oldVType.fieldNames(i), oldVType.types(i))): _*)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package is.hail.types.physical.stypes.concrete

import is.hail.annotations.Region
import is.hail.asm4s.{Settable, TypeInfo, Value}
import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode}
import is.hail.types.physical.{PCanonicalStruct, PType}
import is.hail.types.physical.stypes.{EmitType, SType, SValue}
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue}
import is.hail.types.virtual.{TBaseStruct, TStruct, Type}

object SStructView {
def subset(fieldnames: IndexedSeq[String], struct: SBaseStruct): SStructView =
struct match {
case s: SStructView =>
val pfields = s.parent.virtualType.fields
new SStructView(
s.parent,
fieldnames.map(f => pfields(s.newToOldFieldMapping(s.fieldIdx(f))).name),
s.rename.typeAfterSelectNames(fieldnames),
)

case s =>
val restrict = s.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(fieldnames)
new SStructView(s, fieldnames, restrict)
}
}

// A 'view' on `SBaseStruct`s, ie one that presents an upcast and/or renamed facade on another
final class SStructView(
private val parent: SBaseStruct,
private val restrict: IndexedSeq[String],
private val rename: TStruct,
) extends SBaseStruct {

assert(
parent.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(restrict) canCastTo rename,
s"""Renamed type is not isomorphic to subsetted type
| parent: '${parent.virtualType._toPretty}'
| restrict: '${restrict.mkString("[", ",", "]")}'
| rename: '${rename._toPretty}'
|""".stripMargin,
)

override def size: Int =
restrict.length

lazy val newToOldFieldMapping: Map[Int, Int] =
restrict.view.zipWithIndex.map { case (f, i) => i -> parent.fieldIdx(f) }.toMap

override lazy val fieldTypes: IndexedSeq[SType] =
Array.tabulate(size) { i =>
parent
.fieldTypes(newToOldFieldMapping(i))
.castRename(rename.fields(i).typ)
}
Comment on lines +50 to +55
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this assumes that the fields of restrict have the same virtual type as the corresponding fields in parent. I'm in favor of not making SStructView support deep subsetting for now to keep things simple, but in that case, having an entire restrict type is redundant and, I think, confusing. It would be sufficient to just use a list of parent fields to subset to (but then still the full rename struct type).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


override lazy val fieldEmitTypes: IndexedSeq[EmitType] =
Array.tabulate(size) { i =>
parent
.fieldEmitTypes(newToOldFieldMapping(i))
.copy(st = fieldTypes(i))
}

override def virtualType: TBaseStruct =
rename

override def fieldIdx(fieldName: String): Int =
rename.fieldIdx(fieldName)

override def castRename(t: Type): SType =
new SStructView(parent, restrict, rename = t.asInstanceOf[TStruct])

override def _coerceOrCopy(
cb: EmitCodeBuilder,
region: Value[Region],
value: SValue,
deepCopy: Boolean,
): SValue = {
if (deepCopy)
throw new NotImplementedError("Deep copy on struct view")

value.st match {
case s: SStructView if this == s && !deepCopy =>
value
}
}

override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] =
parent.settableTupleTypes()

override def fromSettables(settables: IndexedSeq[Settable[_]]): SStructViewSettable =
new SStructViewSettable(
this,
parent.fromSettables(settables).asInstanceOf[SBaseStructSettable],
)

override def fromValues(values: IndexedSeq[Value[_]]): SStructViewValue =
new SStructViewValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue])

override def copiedType: SType =
if (virtualType.size < 64)
SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType))
else {
val ct = SBaseStructPointer(storageType().asInstanceOf[PCanonicalStruct])
assert(ct.virtualType == virtualType, s"ct=$ct, this=$this")
ct
}

def storageType(): PType = {
val pt = PCanonicalStruct(
required = false,
args = rename.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*,
)
assert(pt.virtualType == virtualType, s"pt=$pt, this=$this")
pt
}

// aspirational implementation
// def storageType(): PType = StoredSTypePType(this, false)

override def containsPointers: Boolean =
parent.containsPointers

override def equals(obj: Any): Boolean =
obj match {
case s: SStructView =>
rename == s.rename &&
newToOldFieldMapping == s.newToOldFieldMapping &&
parent == s.parent // todo test isIsomorphicTo
case _ =>
false
}
}

class SStructViewValue(val st: SStructView, val prev: SBaseStructValue) extends SBaseStructValue {

override lazy val valueTuple: IndexedSeq[Value[_]] =
prev.valueTuple

override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode =
prev
.loadField(cb, st.newToOldFieldMapping(fieldIdx))
.map(cb)(_.castRename(st.virtualType.fields(fieldIdx).typ))

override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx))
}

final class SStructViewSettable(st: SStructView, prev: SBaseStructSettable)
extends SStructViewValue(st, prev) with SBaseStructSettable {
override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewSettable(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def settableTuple(): IndexedSeq[Settable[_]] =
prev.settableTuple()

override def store(cb: EmitCodeBuilder, pv: SValue): Unit =
prev.store(cb, pv.asInstanceOf[SStructViewValue].prev)
}
Loading