Skip to content

Commit

Permalink
Speeding up Scala 3 Writing macros (#440)
Browse files Browse the repository at this point in the history
Fixes #389 and
#388

This PR moves a bunch of logic from runtime to compile time, and
optimizes whatever runtime logic remains. It also consolidates the logic
with the Scala 2 logic as much as possible, e.g. re-using `writeSnippet`
to manage writes and `CaseObjectContext`/`HugeCaseObjectContext` to
manage validation and error reporting during a read

Running some ad-hoc benchmarks on my laptop, `./mill bench.jvm.run`, it
gives a significant speedup on reads and writes, and brings it close to
the Scala 2.13 numbers (higher is better):

| Benchmark (5000ms) | Scala 3 Before | Scala 3 After | Scala 2 |
|-------------------------------------|----------------|---------------|---------|
| upickleDefault Read | 637 | 1065 | 1403 |
| upickleDefault Write | 839 | 1452 | 1549 |
| upickleDefaultByteArray Read | 582 | 1172 | 1126 |
| upickleDefaultByteArray Write | 847 | 1218 | 1277 |
| upickleDefaultBinary Read | 925 | 3853 | 3844 |
| upickleDefaultBinary Write | 1252 | 3117 | 3666 |
| upickleDefaultCached Read | 620 | 1300 | 1412 |
| upickleDefaultCached Write | 829 | 1555 | 1588 |
| upickleDefaultByteArrayCached Read | 575 | 1182 | 1095 |
| upickleDefaultByteArrayCached Write | 838 | 1223 | 1297 |
| upickleDefaultBinaryCached Read | 928 | 3825 | 3885 |
| upickleDefaultBinaryCached Write | 1266 | 2907 | 3674 |

Note that the generated code, especially for reads, is still not as
optimized as the Scala 2 versions:

1. I couldn't figure out how to generate an anonymous class with typed
fields in Scala 3, so I'm putting things in an `Array[Any]`
2. I couldn't figure out how to generate `match` statements, so I
generated a `Map[K, V]` and look it up at runtime.

Fixing these issues and moving the reading logic into the macro should
also be possible, but can happen in a separate PR.

All existing tests pass. Added a regression test for the recursive Scala
3 scenario that hangs on master. Also moved a bunch of `AdvancedTests`
into the shared folder now that they work
  • Loading branch information
lihaoyi authored Feb 6, 2023
1 parent f7b61ce commit 0f87fae
Show file tree
Hide file tree
Showing 9 changed files with 338 additions and 278 deletions.
35 changes: 21 additions & 14 deletions core/src/upickle/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,25 @@ trait Types{ types =>
abstract class CaseR[V] extends SimpleReader[V]{
override def expectedMsg = "expected dictionary"
override def visitString(s: CharSequence, index: Int) = visitObject(0, true, index).visitEnd(index)
abstract class CaseObjectContext(fieldCount: Int) extends ObjVisitor[Any, V]{
trait BaseCaseObjectContext{
def storeAggregatedValue(currentIndex: Int, v: Any): Unit
var found = 0L
def visitKey(index: Int) = _root_.upickle.core.StringVisitor
var currentIndex = -1
protected def storeValueIfNotFound(i: Int, v: Any): Unit
protected def errorMissingKeys(rawArgsLength: Int, mappedArgs: Array[String]): Unit
protected def checkErrorMissingKeys(rawArgsBitset: Long): Boolean
}

abstract class CaseObjectContext(fieldCount: Int) extends ObjVisitor[Any, V] with BaseCaseObjectContext{
var found = 0L

def visitValue(v: Any, index: Int): Unit = {
if (currentIndex != -1 && ((found & (1L << currentIndex)) == 0)) {
storeAggregatedValue(currentIndex, v)
found |= (1L << currentIndex)
}
}
def visitKey(index: Int) = _root_.upickle.core.StringVisitor

protected def storeValueIfNotFound(i: Int, v: Any) = {
if ((found & (1L << i)) == 0) {
found |= (1L << i)
Expand All @@ -233,17 +241,16 @@ trait Types{ types =>
found != rawArgsBitset
}
}
abstract class HugeCaseObjectContext(fieldCount: Int) extends ObjVisitor[Any, V]{
def storeAggregatedValue(currentIndex: Int, v: Any): Unit
abstract class HugeCaseObjectContext(fieldCount: Int) extends ObjVisitor[Any, V] with BaseCaseObjectContext{
var found = new Array[Long](fieldCount / 64 + 1)
var currentIndex = -1

def visitValue(v: Any, index: Int): Unit = {
if (currentIndex != -1 && ((found(currentIndex / 64) & (1L << currentIndex)) == 0)) {
storeAggregatedValue(currentIndex, v)
found(currentIndex / 64) |= (1L << currentIndex)
}
}
def visitKey(index: Int) = _root_.upickle.core.StringVisitor

protected def storeValueIfNotFound(i: Int, v: Any) = {
if ((found(i / 64) & (1L << i)) == 0) {
found(i / 64) |= (1L << i)
Expand All @@ -259,7 +266,7 @@ trait Types{ types =>
"missing keys in dictionary: " + keys.mkString(", ")
)
}
protected def checkErrorMissingKeys(rawArgsLength: Int) = {
protected def checkErrorMissingKeys(rawArgsLength: Long) = {
var bits = 0
for(v <- found) bits += java.lang.Long.bitCount(v)
bits != rawArgsLength
Expand All @@ -277,16 +284,16 @@ trait Types{ types =>
ctx.visitEnd(-1)
}
}
protected def writeSnippet[R, V](objectAttributeKeyWriteMap: CharSequence => CharSequence,
ctx: _root_.upickle.core.ObjVisitor[_, R],
mappedArgsI: String,
w: Writer[V],
value: V) = {
def writeSnippet[R, V](objectAttributeKeyWriteMap: CharSequence => CharSequence,
ctx: _root_.upickle.core.ObjVisitor[_, R],
mappedArgsI: String,
w: Any,
value: Any) = {
val keyVisitor = ctx.visitKey(-1)
ctx.visitKeyValue(
keyVisitor.visitString(objectAttributeKeyWriteMap(mappedArgsI), -1)
)
ctx.narrow.visitValue(w.write(ctx.subVisitor, value), -1)
ctx.narrow.visitValue(w.asInstanceOf[Writer[Any]].write(ctx.subVisitor, value), -1)
}
}
class SingletonR[T](t: T) extends CaseR[T]{
Expand Down
114 changes: 0 additions & 114 deletions implicits/src-3/upickle/implicits/CaseClassReader.scala

This file was deleted.

80 changes: 0 additions & 80 deletions implicits/src-3/upickle/implicits/CaseClassWriter.scala

This file was deleted.

109 changes: 105 additions & 4 deletions implicits/src-3/upickle/implicits/Readers.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,110 @@
package upickle.implicits

import upickle.core.{ Visitor, ObjVisitor, Annotator }
import compiletime.summonInline
import deriving.Mirror
import upickle.core.{Annotator, ObjVisitor, Visitor, Abort}
import upickle.implicits.macros.EnumDescription

import deriving._, compiletime._

trait ReadersVersionSpecific extends CaseClassReaderPiece:
trait ReadersVersionSpecific extends MacrosCommon:
this: upickle.core.Types with Readers with Annotator =>

/** Reader for a case class `T`, driven by data handed in by the derivation code.
  *
  * @param visitors0       by-name product of the per-field readers; only evaluated when
  *                        `visitors` is first accessed
  * @param fromProduct     builds the final `T` from a `Product` holding the field values
  * @param keyToIndex      maps each (mapped) key name to its field position
  * @param defaultParams   per-field thunks producing default values; `null` entries are skipped
  * @param missingKeyCount value passed to `checkErrorMissingKeys` when the object ends
  */
class CaseReader[T](visitors0: => Product,
                    fromProduct: Product => T,
                    keyToIndex: Map[String, Int],
                    defaultParams: Array[() => Any],
                    missingKeyCount: Long) extends CaseR[T] {

  val paramCount = keyToIndex.size
  lazy val visitors = visitors0
  // Reverse lookup (position -> key); only consulted when reporting missing keys.
  lazy val indexToKey = keyToIndex.map { case (key, position) => position -> key }

  /** Per-object parsing state shared by the small and huge object contexts. */
  trait ObjectContext extends ObjVisitor[Any, T] with BaseCaseObjectContext {
    // One slot per constructor parameter, filled as values (or defaults) arrive.
    private val params = new Array[Any](paramCount)

    def storeAggregatedValue(currentIndex: Int, v: Any): Unit = {
      params(currentIndex) = v
    }

    // Unknown keys (index -1) are consumed by a no-op visitor.
    def subVisitor: Visitor[_, _] = currentIndex match {
      case -1 => upickle.core.NoOpVisitor
      case fieldIdx => visitors.productElement(fieldIdx).asInstanceOf[Visitor[_, _]]
    }

    def visitKeyValue(v: Any): Unit = {
      val mappedKey = objectAttributeKeyReadMap(v.toString).toString
      currentIndex = keyToIndex.getOrElse(mappedKey, -1)
    }

    def visitEnd(index: Int): T = {
      // Offer each available default to storeValueIfNotFound.
      var fieldIdx = 0
      while (fieldIdx < paramCount) {
        val computeDefault = defaultParams(fieldIdx)
        if (computeDefault != null) storeValueIfNotFound(fieldIdx, computeDefault())
        fieldIdx += 1
      }

      // NOTE(review): the original comment here said 64 is special-cased because JVM bit
      // shifts ignore shift counts above 63
      // (https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19).
      // No such special case is visible in this class; presumably it lives wherever
      // `missingKeyCount` is computed -- confirm.
      if (this.checkErrorMissingKeys(missingKeyCount)) {
        val keysByPosition = indexToKey.toSeq
          .sortBy { case (position, _) => position }
          .map { case (_, key) => key }
          .toArray
        this.errorMissingKeys(paramCount, keysByPosition)
      }

      // Expose the collected values as a Product so `fromProduct` can build `T`.
      val fieldValues = new Product {
        def canEqual(that: Any): Boolean = true
        def productArity: Int = params.length
        def productElement(i: Int): Any = params(i)
      }
      fromProduct(fieldValues)
    }
  }

  // Up to 64 parameters use CaseObjectContext; anything larger uses HugeCaseObjectContext.
  override def visitObject(length: Int, jsonableKeys: Boolean, index: Int) =
    if (paramCount <= 64) new CaseObjectContext(paramCount) with ObjectContext
    else new HugeCaseObjectContext(paramCount) with ObjectContext
}

/** Reader for enumeration values encoded as strings.
  *
  * @param f           turns the incoming string into a `T`; an `IllegalArgumentException`
  *                    thrown by it is reported as an `Abort`
  * @param description enumeration description whose `pretty` form is used in the error message
  */
class EnumReader[T](f: String => T, description: EnumDescription) extends SimpleReader[T]:
  override def expectedMsg = "expected string enumeration"

  override def visitString(s: CharSequence, index: Int) =
    val str = s.toString
    try f(str)
    catch
      // Translate the lookup failure into a parse abort that names the offending value.
      case _: IllegalArgumentException =>
        throw new Abort(s"Value '$str' was not found in enumeration ${description.pretty}")
end EnumReader

/** Derives a `Reader[T]` from the `Mirror` of `T`, resolved at compile time via `inline match`.
  *
  * - Product types get a `CaseReader`, which is passed through `annotate` with the full
  *   class name when `T` is a member of a sealed hierarchy.
  * - Sum types that are `scala.reflect.Enum` get an `EnumReader`.
  * - All other sum types get the merge of the summoned readers of every case.
  */
inline def macroR[T](using m: Mirror.Of[T]): Reader[T] = inline m match {
case m: Mirror.ProductOf[T] =>

// Arguments, in order: summoned per-field readers, the Mirror's constructor,
// key -> position map, default-value array, and the value used for the missing-key check.
// NOTE(review): `fieldLabels[T].map(_._2)` presumably selects the mapped (serialized)
// key name of each field -- confirm against `macros.fieldLabels`.
val reader = new CaseReader(
compiletime.summonAll[Tuple.Map[m.MirroredElemTypes, Reader]],
m.fromProduct(_),
macros.fieldLabels[T].map(_._2).zipWithIndex.toMap,
macros.getDefaultParamsArray[T],
macros.checkErrorMissingKeysCount[T]()
)

if macros.isMemberOfSealedHierarchy[T] then annotate(reader, macros.fullClassName[T])
else reader

case m: Mirror.SumOf[T] =>
// Decide at compile time whether T is a Scala 3 enum or a general sum type.
inline compiletime.erasedValue[T] match {
case _: scala.reflect.Enum =>
val valueOf = macros.enumValueOf[T]
val description = macros.enumDescription[T]
new EnumReader[T](valueOf, description)
case _ =>
// Summon a Reader for each case of the sum type, then merge them into one Reader[T].
val readers: List[Reader[_ <: T]] = compiletime.summonAll[Tuple.Map[m.MirroredElemTypes, Reader]]
.toList
.asInstanceOf[List[Reader[_ <: T]]]

Reader.merge[T](readers: _*)
}
}

// Given Reader for any singleton type that has a Mirror; delegates to macroR.
inline given[T <: Singleton : Mirror.Of]: Reader[T] = macroR[T]

// see comment in MacroImplicits as to why Dotty's extension methods aren't used here
implicit class ReaderExtension(r: Reader.type) {
  /** Adds `Reader.derived[T]`, which forwards to `macroR[T]`. */
  inline def derived[T](using Mirror.Of[T]): Reader[T] = macroR[T]
}
end ReadersVersionSpecific
Loading

0 comments on commit 0f87fae

Please sign in to comment.