Skip to content

Commit

Permalink
[compiler] Add If and MakeStruct Simplifications (#14231)
Browse files Browse the repository at this point in the history
Add the following transformations:
```scala
MakeStruct("a" -> GetField(o, "x"), ...) -> CastRename(SelectFields(o, ["x"..]), newtype)
If(IsNA(x), NA(x.typ), x) -> x
```

The changes to `SStructView` (nee `SSubsetStruct`) were as a result of a
bud that prevented subsetting and then renaming to an excluded field, ie
`{x, y, z} subset {z} rename {x}`. Now `SStructView` leaves its parent
`SType` unmodified and casts loads through the parent to the appropriate
`SValue`.
  • Loading branch information
ehigham committed Feb 10, 2024
1 parent faae93c commit a515ccd
Show file tree
Hide file tree
Showing 13 changed files with 363 additions and 153 deletions.
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {
iec.map(cb)(pc => cast(cb, pc))
case CastRename(v, _typ) =>
emitI(v)
.map(cb)(pc => pc.st.castRename(_typ).fromValues(pc.valueTuple))
.map(cb)(_.castRename(_typ))
case NA(typ) =>
IEmitCode.missing(cb, SUnreachable.fromVirtualType(typ).defaultValue)
case IsNA(v) =>
Expand Down
33 changes: 33 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Simplify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import is.hail.types.tcoerce
import is.hail.types.virtual._
import is.hail.utils._

import scala.collection.mutable

object Simplify {

/** Transform 'ir' using simplification rules until none apply. */
Expand Down Expand Up @@ -248,6 +250,8 @@ object Simplify {
else
If(IsNA(c), NA(cnsq.typ), cnsq)

case If(IsNA(a), NA(_), b) if a == b => b

case If(ApplyUnaryPrimOp(Bang, c), cnsq, altr) => If(c, altr, cnsq)

case If(c1, If(c2, cnsq2, _), altr1) if c1 == c2 => If(c1, cnsq2, altr1)
Expand Down Expand Up @@ -546,6 +550,10 @@ object Simplify {
selectFields.filter(f => !insertNames.contains(f)) ++ oldFields.map(_._1)
InsertFields(SelectFields(struct, preservedFields), newFields, Some(fields.toFastSeq))

case MakeStructOfGetField(o, newNames) =>
val select = SelectFields(o, newNames.map(_._1))
CastRename(select, select.typ.asInstanceOf[TStruct].rename(newNames.toMap))

case GetTupleElement(MakeTuple(xs), idx) => xs.find(_._1 == idx).get._2

case TableCount(MatrixColsTable(child)) if child.columnCount.isDefined =>
Expand Down Expand Up @@ -1350,4 +1358,29 @@ object Simplify {
typ.blockSize,
)
}

// Match on expressions of the form
// MakeStruct(IndexedSeq(a -> GetField(o, x) [, b -> GetField(o, y), ...]))
// where
// - all fields are extracted from the same object, `o`
// - all references to the fields in o are unique
private object MakeStructOfGetField {
def unapply(ir: IR): Option[(IR, IndexedSeq[(String, String)])] =
ir match {
case MakeStruct(fields) if fields.nonEmpty =>
val names = mutable.HashSet.empty[String]
val rewrites = new BoxedArrayBuilder[(String, String)](fields.length)

fields.view.map {
case (a, GetField(o, b)) if names.add(b) =>
rewrites += (b -> a)
Some(o)
case _ => None
}
.reduce((a, b) => if (a == b) a else None)
.map(_ -> rewrites.underlying().toFastSeq)
case _ =>
None
}
}
}
47 changes: 28 additions & 19 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2378,39 +2378,48 @@ object LowerTableIR {
lowered
}

// format: off

/* We have a couple of options when repartitioning a table:
* 1. Send only the contexts needed to compute each new partition and take/drop the rows that fall
* in that partition.
* 2. Compute the table with the old partitioner, write the table to cloud storage then read the
* new partitions from the index.
* 1. Send only the contexts needed to compute each new partition and
* take/drop the rows that fall in that partition.
* 2. Compute the table with the old partitioner, write the table to cloud
* storage then read the new partitions from the index.
*
* We'd like to do 1 as keeping things in memory (with perhaps a bit of work duplication) is
* generally less expensive than writing and reading a table to and from cloud storage. There
* comes a cross-over point, however, where it's cheaper to do the latter. One such example is as
* follows: consider a repartitioning where the same context is used to compute multiple
* partitions. The (parallel) computation of each partition involves at least all of the work to
* compute the previous partition:
* We'd like to do 1 as keeping things in memory (with perhaps a bit of work
* duplication) is generally less expensive than writing and reading a table
* to and from cloud storage. There comes a cross-over point, however, where
* it's cheaper to do the latter. One such example is as follows: consider a
* repartitioning where the same context is used to compute multiple
* partitions. The (parallel) computation of each partition involves at least
* all of the work to compute the previous partition:
*
* *----------------------* in: | | ...
* *----------------------* / | \ / | \
* *--* *---* *--* out: | | | | ... | |
* *--* *---* *--*
* *----------------------*
* in: | | ...
* *----------------------*
* / | \
* / | \
* *--* *---* *--*
* out: | | | | ... | |
* *--* *---* *--*
*
* We can estimate the relative cost of computing the new partitions vs spilling as being
* proportional to the mean number of old partitions used to compute new partitions. */
* We can estimate the relative cost of computing the new partitions vs
* spilling as being proportional to the mean number of old partitions
* used to compute new partitions.
*/
def isRepartitioningCheap(original: RVDPartitioner, planned: RVDPartitioner): Boolean = {
val cost =
if (original.numPartitions == 0)
0.0
else
(0.0167 / original.numPartitions) * planned
.rangeBounds
.map { intrvl =>
val (lo, hi) = original.intervalRange(intrvl); hi - lo
}
.map { intrvl => val (lo, hi) = original.intervalRange(intrvl); hi - lo }
.sum

log.info(s"repartition cost: $cost")
cost <= 1.0
}

// format: on
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ abstract class PBaseStruct extends PType {

def size: Int = fields.length

def isIsomorphicTo(other: PBaseStruct) =
def isIsomorphicTo(other: PBaseStruct): Boolean =
this.fields.size == other.fields.size && this.isCompatibleWith(other)

def _toPretty: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.asm4s.{Code, Value}
import is.hail.backend.HailStateManager
import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.SValue
import is.hail.types.physical.stypes.concrete.SSubsetStruct
import is.hail.types.physical.stypes.concrete.SStructView
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue}
import is.hail.types.virtual.TStruct
import is.hail.utils._
Expand Down Expand Up @@ -137,7 +137,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext
): Long =
throw new UnsupportedOperationException

def sType: SSubsetStruct = SSubsetStruct(ps.sType.asInstanceOf[SBaseStruct], _fieldNames)
def sType: SBaseStruct =
SStructView.subset(_fieldNames, ps.sType)

def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean)
: Value[Long] =
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.expr.ir.EmitCodeBuilder
import is.hail.types.physical.stypes.concrete.SRNGStateValue
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives._
import is.hail.types.virtual.Type

object SCode {
def add(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = {
Expand Down Expand Up @@ -114,6 +115,9 @@ trait SValue {
throw new UnsupportedOperationException(s"Stype $st has no hashcode")

def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value

def castRename(t: Type): SValue =
st.castRename(t).fromValues(valueTuple)
}

trait SSettable extends SValue {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue])
override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
values(fieldIdx).m

override def subset(fieldNames: String*): SStackStructValue = {
override def subset(fieldNames: String*): SBaseStructValue = {
val newToOld = fieldNames.map(st.fieldIdx).toArray
val oldVType = st.virtualType.asInstanceOf[TStruct]
val newVirtualType = TStruct(newToOld.map(i => (oldVType.fieldNames(i), oldVType.types(i))): _*)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package is.hail.types.physical.stypes.concrete

import is.hail.annotations.Region
import is.hail.asm4s.{Settable, TypeInfo, Value}
import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode}
import is.hail.types.physical.{PCanonicalStruct, PType}
import is.hail.types.physical.stypes.{EmitType, SType, SValue}
import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue}
import is.hail.types.virtual.{TBaseStruct, TStruct, Type}

object SStructView {
def subset(fieldnames: IndexedSeq[String], struct: SBaseStruct): SStructView =
struct match {
case s: SStructView =>
val pfields = s.parent.virtualType.fields
new SStructView(
s.parent,
fieldnames.map(f => pfields(s.newToOldFieldMapping(s.fieldIdx(f))).name),
s.rename.typeAfterSelectNames(fieldnames),
)

case s =>
val restrict = s.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(fieldnames)
new SStructView(s, fieldnames, restrict)
}
}

// A 'view' on `SBaseStruct`s, ie one that presents an upcast and/or renamed facade on another
final class SStructView(
private val parent: SBaseStruct,
private val restrict: IndexedSeq[String],
private val rename: TStruct,
) extends SBaseStruct {

assert(
parent.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(restrict) canCastTo rename,
s"""Renamed type is not isomorphic to subsetted type
| parent: '${parent.virtualType._toPretty}'
| restrict: '${restrict.mkString("[", ",", "]")}'
| rename: '${rename._toPretty}'
|""".stripMargin,
)

override def size: Int =
restrict.length

lazy val newToOldFieldMapping: Map[Int, Int] =
restrict.view.zipWithIndex.map { case (f, i) => i -> parent.fieldIdx(f) }.toMap

override lazy val fieldTypes: IndexedSeq[SType] =
Array.tabulate(size) { i =>
parent
.fieldTypes(newToOldFieldMapping(i))
.castRename(rename.fields(i).typ)
}

override lazy val fieldEmitTypes: IndexedSeq[EmitType] =
Array.tabulate(size) { i =>
parent
.fieldEmitTypes(newToOldFieldMapping(i))
.copy(st = fieldTypes(i))
}

override def virtualType: TBaseStruct =
rename

override def fieldIdx(fieldName: String): Int =
rename.fieldIdx(fieldName)

override def castRename(t: Type): SType =
new SStructView(parent, restrict, rename = t.asInstanceOf[TStruct])

override def _coerceOrCopy(
cb: EmitCodeBuilder,
region: Value[Region],
value: SValue,
deepCopy: Boolean,
): SValue = {
if (deepCopy)
throw new NotImplementedError("Deep copy on struct view")

value.st match {
case s: SStructView if this == s && !deepCopy =>
value
}
}

override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] =
parent.settableTupleTypes()

override def fromSettables(settables: IndexedSeq[Settable[_]]): SStructViewSettable =
new SStructViewSettable(
this,
parent.fromSettables(settables).asInstanceOf[SBaseStructSettable],
)

override def fromValues(values: IndexedSeq[Value[_]]): SStructViewValue =
new SStructViewValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue])

override def copiedType: SType =
if (virtualType.size < 64)
SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType))
else {
val ct = SBaseStructPointer(storageType().asInstanceOf[PCanonicalStruct])
assert(ct.virtualType == virtualType, s"ct=$ct, this=$this")
ct
}

def storageType(): PType = {
val pt = PCanonicalStruct(
required = false,
args = rename.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*,
)
assert(pt.virtualType == virtualType, s"pt=$pt, this=$this")
pt
}

// aspirational implementation
// def storageType(): PType = StoredSTypePType(this, false)

override def containsPointers: Boolean =
parent.containsPointers

override def equals(obj: Any): Boolean =
obj match {
case s: SStructView =>
rename == s.rename &&
newToOldFieldMapping == s.newToOldFieldMapping &&
parent == s.parent // todo test isIsomorphicTo
case _ =>
false
}
}

class SStructViewValue(val st: SStructView, val prev: SBaseStructValue) extends SBaseStructValue {

override lazy val valueTuple: IndexedSeq[Value[_]] =
prev.valueTuple

override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode =
prev
.loadField(cb, st.newToOldFieldMapping(fieldIdx))
.map(cb)(_.castRename(st.virtualType.fields(fieldIdx).typ))

override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] =
prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx))
}

final class SStructViewSettable(st: SStructView, prev: SBaseStructSettable)
extends SStructViewValue(st, prev) with SBaseStructSettable {
override def subset(fieldNames: String*): SBaseStructValue =
new SStructViewSettable(SStructView.subset(fieldNames.toIndexedSeq, st), prev)

override def settableTuple(): IndexedSeq[Settable[_]] =
prev.settableTuple()

override def store(cb: EmitCodeBuilder, pv: SValue): Unit =
prev.store(cb, pv.asInstanceOf[SStructViewValue].prev)
}
Loading

0 comments on commit a515ccd

Please sign in to comment.