Skip to content

Commit

Permalink
Recursively resolve refinements (#272)
Browse files Browse the repository at this point in the history
Class refinements and block-based default can now chain, eg
PassiveHeader -> PinConnector254 and PinConnector254 ->
PinConnector254Vertical will result in PassiveHeader ->
PinConnector254Vertical.
  • Loading branch information
ducky64 authored Jul 25, 2023
1 parent 4766498 commit c1034d0
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 36 deletions.
75 changes: 42 additions & 33 deletions compiler/src/main/scala/edg/compiler/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import edg.util.{DependencyGraph, Errorable, SingleWriteHashMap}
import edg.wir.ProtoUtil._
import edg.wir._
import edg.{ExprBuilder, wir}
import edgir.elem.elem
import edgir.expr.expr
import edgir.ref.ref
import edgir.init.init
Expand Down Expand Up @@ -530,45 +531,52 @@ class Compiler private (
}
}

// Recursively resolves the refined library, accounting for instance, class, and default refinements.
// Returns none of no further refinement is specified.
protected def resolveRefinementLibrary(
path: DesignPath,
blockLibrary: ref.LibraryPath,
blockPb: elem.HierarchyBlock,
applyInstanceRefinement: Boolean = true // should only be true at the beginning
): Option[(ref.LibraryPath, elem.HierarchyBlock)] = {
val nextLibrary = Option.when(applyInstanceRefinement)(refinements.instanceRefinements.get(path)).flatten
.orElse(refinements.classRefinements.get(blockLibrary))
.orElse(blockPb.defaultRefinement)

nextLibrary.map { nextLibrary =>
val nextBlockPb = library.getBlock(nextLibrary) match {
case Errorable.Success(blockPb) => blockPb
case Errorable.Error(err) =>
errors += CompilerError.LibraryError(path, nextLibrary, err)
elem.HierarchyBlock()
}
resolveRefinementLibrary(path, nextLibrary, nextBlockPb, false).getOrElse((nextLibrary, nextBlockPb))
}
}

// Given a block library at some path, expand it and link it in the parent.
// Does not elaborate the internals (including connections / assertions / assignments), which is
// a separate phase that (for generators) may be gated on additional parameters.
// Handles class type refinements and adds default parameters and class-based value refinements
// For the generator, this will be a skeleton block.
protected def expandBlock(path: DesignPath): Unit = {
import edgir.elem.elem

val block = resolveBlock(path).asInstanceOf[wir.BlockLibrary]

// check for and apply block-side default refinement, if defined
val libraryBlockPb = library.getBlock(block.target, block.mixins) match {
val prerefineBlockPb = library.getBlock(block.target, block.mixins) match {
case Errorable.Success(blockPb) => blockPb
case Errorable.Error(err) =>
errors += CompilerError.LibraryError(path, block.target, err)
elem.HierarchyBlock()
}
val refinementLibraryPath = refinements.instanceRefinements.get(path).orElse(
refinements.classRefinements.get(block.target).orElse(
libraryBlockPb.defaultRefinement
)
)

// actually instantiate the block
val unrefinedType = if (refinementLibraryPath.isDefined) Some(block.target) else None
val blockLibraryPath = refinementLibraryPath.getOrElse(block.target)
val blockMixins = if (refinementLibraryPath.isDefined) Seq() else block.mixins // discard mixins if refined

val blockPb = library.getBlock(blockLibraryPath, blockMixins) match {
case Errorable.Success(blockPb) =>
blockPb
case Errorable.Error(err) =>
errors += CompilerError.LibraryError(path, blockLibraryPath, err)
elem.HierarchyBlock()
}
val (isRefined, refinedLibrary, refinedPb) =
resolveRefinementLibrary(path, block.target, prerefineBlockPb) match {
case Some((refinedPath, refinedPb)) => (true, refinedPath, refinedPb)
case None => (false, block.target, prerefineBlockPb)
}

// add class-based refinements - must be set before refinement params
// note that this operates on the post-refinement class
val blockAllClasses = Seq(blockPb.selfClass, blockPb.superclasses, blockPb.superSuperclasses).flatten
val blockAllClasses = Seq(refinedPb.selfClass, refinedPb.superclasses, refinedPb.superSuperclasses).flatten
filterRefinementClassValues(blockAllClasses, refinementClassValuesByClass).foreach {
case ((refinementClass, postfix), value) =>
val paramPath = path ++ postfix
Expand All @@ -584,44 +592,45 @@ class Compiler private (
}

// additional processing needed for the refinement case
val unrefinedType = if (isRefined) Some(block.target) else None
if (unrefinedType.isDefined) {
val refinedNewParams = blockPb.params.toSeqMap.keys.toSet -- libraryBlockPb.params.toSeqMap.keys
val refinedNewParams = refinedPb.params.toSeqMap.keys.toSet -- prerefineBlockPb.params.toSeqMap.keys
refinedNewParams.foreach { refinedNewParam => // add subclass (refinement) default params
blockPb.paramDefaults.get(refinedNewParam).foreach { refinedDefault =>
refinedPb.paramDefaults.get(refinedNewParam).foreach { refinedDefault =>
constProp.addAssignExpr(
path.asIndirect + refinedNewParam,
refinedDefault,
path,
s"(default)${blockLibraryPath.toSimpleString}.$refinedNewParam"
s"(default)${refinedLibrary.toSimpleString}.$refinedNewParam"
)
}
}
val refinedNewPorts = blockPb.ports.toSeqMap.keys.toSet -- libraryBlockPb.ports.toSeqMap.keys
val refinedNewPorts = refinedPb.ports.toSeqMap.keys.toSet -- prerefineBlockPb.ports.toSeqMap.keys
refinedNewPorts.foreach { refinedNewPort => // add subclass (refinement) non-connected
blockPb.ports(refinedNewPort).is match {
refinedPb.ports(refinedNewPort).is match {
case _: elem.PortLike.Is.LibElem =>
constProp.addAssignValue(
path.asIndirect + refinedNewPort + IndirectStep.IsConnected,
BooleanValue(false),
path,
s"(refined_not_connected)${blockLibraryPath.toSimpleString}.$refinedNewPort"
s"(refined_not_connected)${refinedLibrary.toSimpleString}.$refinedNewPort"
)
case _: elem.PortLike.Is.Array =>
constProp.addAssignValue(
path.asIndirect + refinedNewPort + IndirectStep.Allocated,
ArrayValue(Seq()),
path,
s"(refined_not_connected)${blockLibraryPath.toSimpleString}.$refinedNewPort"
s"(refined_not_connected)${refinedLibrary.toSimpleString}.$refinedNewPort"
)
case _ => throw new IllegalArgumentException(s"unknown port $refinedNewPort")
}
}
}

val newBlock = if (blockPb.generator.isEmpty) {
new wir.Block(blockPb, unrefinedType, block.mixins)
val newBlock = if (refinedPb.generator.isEmpty) {
new wir.Block(refinedPb, unrefinedType, block.mixins)
} else {
new wir.Generator(blockPb, unrefinedType, block.mixins)
new wir.Generator(refinedPb, unrefinedType, block.mixins)
}

val (parentPath, blockName) = path.split
Expand Down
77 changes: 77 additions & 0 deletions compiler/src/test/scala/edg/compiler/CompilerRefinementTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ class CompilerRefinementTest extends AnyFlatSpec with CompilerTestUtil {
"newArray" -> Port.Array("port", Seq(), Port.Library("port")),
)
),
Block.Block(
"subsubclassBlock",
superclasses = Seq("subclassBlock"),
superSuperclasses = Seq("superclassBlock", "superclassDefaultBlock"),
params = SeqMap(
"superParam" -> ValInit.Integer,
"subParam" -> ValInit.Integer,
),
ports = SeqMap(
"port" -> Port.Library("port"),
)
),
Block.Block(
"block", // specifically no superclass
params = SeqMap(
Expand Down Expand Up @@ -296,4 +308,69 @@ class CompilerRefinementTest extends AnyFlatSpec with CompilerTestUtil {
))
testCompile(blockDefaultInputDesign, library, expectedDesign = Some(expected))
}

"Compiler on design with chained (instance + class) refinement" should "work" in {
val expected = Design(Block.Block(
"topDesign",
blocks = SeqMap(
"block" -> Block.Block(
"subsubclassBlock",
superclasses = Seq("subclassBlock"),
superSuperclasses = Seq("superclassBlock", "superclassDefaultBlock"),
prerefine = "superclassBlock",
params = SeqMap(
"superParam" -> ValInit.Integer,
"subParam" -> ValInit.Integer,
),
ports = SeqMap(
"port" -> Port.Port(selfClass = "port"),
)
),
)
))
testCompile(
inputDesign,
library,
refinements = Refinements(
instanceRefinements = Map(DesignPath() + "block" -> LibraryPath("subclassBlock")),
classRefinements = Map(LibraryPath("subclassBlock") -> LibraryPath("subsubclassBlock"))
),
expectedDesign = Some(expected)
)
}

"Compiler on design with chained (default + class) refinement" should "work" in {
val blockDefaultInputDesign = Design(Block.Block(
"topDesign",
blocks = SeqMap(
"block" -> Block.Library("superclassDefaultBlock"),
)
))
val expected = Design(Block.Block(
"topDesign",
blocks = SeqMap(
"block" -> Block.Block(
"subsubclassBlock",
superclasses = Seq("subclassBlock"),
superSuperclasses = Seq("superclassBlock", "superclassDefaultBlock"),
prerefine = "superclassDefaultBlock",
params = SeqMap(
"superParam" -> ValInit.Integer,
"subParam" -> ValInit.Integer,
),
ports = SeqMap(
"port" -> Port.Port(selfClass = "port"),
)
),
)
))
testCompile(
blockDefaultInputDesign,
library,
refinements = Refinements(
classRefinements = Map(LibraryPath("subclassBlock") -> LibraryPath("subsubclassBlock"))
),
expectedDesign = Some(expected)
)
}
}
Binary file modified edg_core/resources/edg-compiler-precompiled.jar
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/test_robotowl.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def refinements(self) -> Refinements:
(['reg_12v', 'power_path', 'inductor', 'manual_frequency_rating'], Range(0, 7e6)),
],
class_refinements=[
(PassiveConnector, PinHeader254Horizontal), # default connector series unless otherwise specified
(PassiveConnector, PinHeader254), # default connector series unless otherwise specified
(PinHeader254, PinHeader254Horizontal),
(TestPoint, CompactKeystone5015),
(Speaker, ConnectorSpeaker),
Expand Down
2 changes: 1 addition & 1 deletion examples/test_switch_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def refinements(self) -> Refinements:
(['mcu'], Esp32_Wroom_32),
(['reg_3v3'], Ld1117),

(['conn', 'conn'], PinHeader254Vertical),
(['conn', 'conn'], PinHeader254),
],
instance_values=[
(['mcu', 'pin_assigns'], [
Expand Down
2 changes: 1 addition & 1 deletion examples/test_usb_uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def contents(self) -> None:
def refinements(self) -> Refinements:
return super().refinements() + Refinements(
instance_refinements=[
(['out', 'conn'], PinHeader254Vertical),
(['out', 'conn'], PinHeader254),
(['reg_3v3'], Ap2204k),
],
instance_values=[
Expand Down

0 comments on commit c1034d0

Please sign in to comment.