Factor Graph Generation from Scala Code
Conceptually, generating factor graphs happens in three stages:
- Converting the sample space into a set of nodes isomorphic to the sample space (for every joint state of all nodes there is an element in the sample space, and vice versa)
- Converting the predicate/filter to assignments and domain restrictions for nodes
- Converting the objective to a set of factors
Each sample space expression can be considered as a tree of sub-spaces. For each node/sub-space in this tree we create a class of type `Structure`. Objects of this class store all nodes (recursively using other `Structure` instances if necessary) needed to represent values in the sub-space.
/**
 * A structure is a collection of MPGraph nodes whose assignments correspond to values of type `T`.
 * @tparam T the type of the values this structure can generate.
 */
trait Structure[+T] {
  /**
   * @return all nodes in this structure (including nodes of substructures)
   */
  def nodes(): Iterator[MPGraph.Node]
  /**
   * @return the value that the current assignment to all nodes is representing.
   */
  def value(): T
  /**
   * Sets all nodes to their argmax belief. todo: this could be generically implemented using nodes().
   */
  def setToArgmax(): Unit
  /**
   * Resets the state of all nodes.
   */
  def resetSetting(): Unit
  /**
   * @return whether there is a next state that the structure can take on, or whether we have iterated over all its states.
   */
  def hasNextSetting: Boolean
  /**
   * Sets the structure to its next state by changing one or more of its nodes' assignments.
   */
  def nextSetting(): Unit
}
Once we have converted our sample space into this format, we can run inference on the MPGraph (assuming we have added factors accordingly), call `setToArgmax`, and then return `value()` to the client.
We are considering the following types of spaces.
Consider a sample space
Seq(1,2,3)
Roughly speaking, this sample space is mapped to the Structure class
final class AtomicStructure extends Structure[Int] {
  val atomDom = Seq(1, 2, 3).toArray
  val atomIndex = atomDom.zipWithIndex.toMap
  // mpGraph is a variable storing an MPGraph somewhere in scope.
  val node = mpGraph.addNode(atomDom.length)
  private def updateValue(): Unit = node.value = node.domain(node.setting)
  def value() = atomDom(node.value)
  def nodes() = Iterator(node)
  def resetSetting(): Unit = node.setting = -1
  def hasNextSetting = node.setting < node.dim - 1
  def nextSetting() = {
    node.setting += 1
    updateValue()
  }
  def setToArgmax(): Unit = {
    node.setting = MoreArrayOps.maxIndex(node.b)
    updateValue()
  }
  // This method is not in the trait for covariance reasons, but it is still essential and every
  // structure needs it. Generally, calls to methods in the Structure interface should be
  // non-polymorphic anyway, so the trait is really just a guideline for what needs to be implemented.
  final def observe(value: Int): Unit = {
    val index = atomIndex(value)
    node.domain = Array(index)
    node.dim = 1
  }
}
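To make the settings protocol concrete without an MPGraph in scope, the following is a minimal sketch that mirrors the generated class above. `ToyNode` and `ToyAtomicStructure` are hypothetical stand-ins introduced only for illustration; they are not part of the actual library, and `allValues` is a helper added here to show how `resetSetting`, `hasNextSetting`, and `nextSetting` might be combined.

```scala
// Hypothetical stand-in for an MPGraph node: only the fields used by the generated class above.
final class ToyNode(var dim: Int) {
  var domain: Array[Int] = Array.range(0, dim) // admissible value indices
  var setting: Int = -1                        // position within `domain`; -1 means "before the first state"
  var value: Int = -1                          // currently selected value index
}

// Hypothetical simplification of the generated atomic structure (no MPGraph involved).
final class ToyAtomicStructure[T](atoms: Seq[T]) {
  private val atomDom = atoms.toVector
  private val atomIndex = atomDom.zipWithIndex.toMap
  val node = new ToyNode(atomDom.length)

  def value(): T = atomDom(node.value)
  def resetSetting(): Unit = { node.setting = -1 }
  def hasNextSetting: Boolean = node.setting < node.dim - 1
  def nextSetting(): Unit = {
    node.setting += 1
    node.value = node.domain(node.setting)
  }

  // Restrict the node's domain to the single index of the observed value.
  def observe(observed: T): Unit = {
    node.domain = Array(atomIndex(observed))
    node.dim = 1
  }

  // Enumerate every value reachable through the settings protocol.
  def allValues(): List[T] = {
    resetSetting()
    val buffer = List.newBuilder[T]
    while (hasNextSetting) { nextSetting(); buffer += value() }
    buffer.result()
  }
}

object AtomicDemo {
  def main(args: Array[String]): Unit = {
    val s = new ToyAtomicStructure(Seq(1, 2, 3))
    println(s.allValues()) // List(1, 2, 3)
    s.observe(2)
    println(s.allValues()) // List(2)
  }
}
```

After `observe`, the node has a single admissible index, so iterating over its settings visits only the observed value.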
Consider a sample space
case class Person(name:String,age:Int)
val names = Seq("Vivek","Sameer")
all(Person)(c(names, Range(0,100)))
We create the following Structure Class
final class PersonStructure extends Structure[Person] {
  final class AtomicStructure1 extends Structure[String] { /* see AtomicStructure above ... */ }
  final class AtomicStructure2 extends Structure[Int] { ... }
  val name = new AtomicStructure1()
  val age = new AtomicStructure2()
  def fields: Iterator[Structure[Any]] = Iterator(name, age)
  def value(): Person = new Person(name.value(), age.value())
  def nodes(): Iterator[Node] = fields.flatMap(_.nodes())
  private var iterator: Iterator[Unit] = _
  def resetSetting(): Unit = iterator = Structure.settingsIterator(List(name, age).reverse)()
  def hasNextSetting = iterator.hasNext
  def nextSetting() = iterator.next()
  def setToArgmax(): Unit = fields.foreach(_.setToArgmax())
  def observe(value: Person): Unit = {
    name.observe(value.name)
    age.observe(value.age)
  }
}
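The case class structure delegates iteration over its joint settings to `Structure.settingsIterator`, whose implementation is not shown here. As a rough sketch of how such a joint iterator could work (not the library's actual code), the following uses an odometer scheme over hypothetical `Settable` parts; `Settable`, `Counter`, and `JointSettings` are names invented for this example.

```scala
// Hypothetical minimal settings protocol shared by the toy structures in this sketch.
trait Settable {
  def resetSetting(): Unit
  def hasNextSetting: Boolean
  def nextSetting(): Unit
}

// A toy counter over 0 until size, standing in for an atomic structure.
final class Counter(size: Int) extends Settable {
  var current: Int = -1
  def resetSetting(): Unit = { current = -1 }
  def hasNextSetting: Boolean = current < size - 1
  def nextSetting(): Unit = { current += 1 }
}

object JointSettings {
  // Odometer-style iterator: each call to next() moves the parts to the next joint state.
  // The first element of `parts` changes fastest.
  def iterator(parts: List[Settable]): Iterator[Unit] = new Iterator[Unit] {
    private var started = false
    parts.foreach(_.resetSetting())

    def hasNext: Boolean =
      if (!started) parts.forall(_.hasNextSetting) // every part needs at least one state
      else parts.exists(_.hasNextSetting)          // some part can still advance

    def next(): Unit = {
      if (!started) {
        parts.foreach(_.nextSetting()) // move every part to its first state
        started = true
      } else {
        // Advance the first part that can still move; rewind all exhausted parts before it.
        val (exhausted, rest) = parts.span(!_.hasNextSetting)
        rest.head.nextSetting()
        exhausted.foreach { p => p.resetSetting(); p.nextSetting() }
      }
    }
  }
}

object JointDemo {
  def states(): List[(Int, Int)] = {
    val a = new Counter(2)
    val b = new Counter(3)
    val it = JointSettings.iterator(List(a, b))
    val out = List.newBuilder[(Int, Int)]
    while (it.hasNext) { it.next(); out += ((a.current, b.current)) }
    out.result()
  }

  def main(args: Array[String]): Unit =
    println(states()) // List((0,0), (1,0), (0,1), (1,1), (0,2), (1,2))
}
```

The iterator returns `Unit` because the state lives in the parts themselves: callers read the current joint value from the structures after each `next()`.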
//implemented, but documentation missing
Let's assume we have the case class sample space from above, and the following `argmax` macro call
argmax(sampleSpace)(p => p.name == "Vivek")(obj)
We incorporate this predicate into the factor graph setup as follows. Most important for this step, and the objective part later, are structure matchers. Structure matchers are functions that take a tree and return the structure sub-object corresponding to the tree, if any. With such a matcher, and the `observe` methods we have on structure objects (not in the trait currently but in its implementations), incorporating this predicate is easy:
def matchStructure(tree: Tree): Option[Tree] = ???
predicate match {
  case q"$x == $value" => matchStructure(x) match {
    case Some(structure) => addToCode(q"$structure.observe($value)")
    case _ => // return the predicate as is and just turn it into a deterministic factor
  }
}
Note that this is simplified code; the actual code looks a little different. There is a long list of todos here. First, it currently cannot deal with conjunctions, although this should be easy to add.
Akin to the sample space to node mapping, the objective to factor mapping traverses the objective tree. However, we do not currently generate a separate class for each node in the objective tree. Instead we have one Scala expression per node that generates the factors (but doesn't keep them in an object).
The simplest conversion case arises when the objective (or, during traversal, a sub-objective) is atomic, i.e. a leaf that cannot be broken down further. Assume again the Person sample space from above and the following right-hand side of the objective
atomicScoringFunctionDefinedElsewhere(person.name)
where `def atomicScoringFunctionDefinedElsewhere(name:String):Double` is a function whose definition we do not have access to (or one that we have annotated with a special `Atomic` annotation, to be implemented later).
We proceed in the following steps:
- Use a structure matcher to find all structures involved in the objective (in the finest possible granularity). Assuming that the root (case class) structure is called `root`, we could get the list of structures `List(root.name)`.
- Use a structure matcher to replace expressions that refer to data fields with expressions that refer to values of the generated `Structure` object. For example, assume that the root structure is called `root`, then `person.name` is replaced by `root.name.value()`. Hence the objective becomes `atomicScoringFunctionDefinedElsewhere(root.name.value())`.
- Iterate over all settings of the involved structures and call `atomicScoringFunctionDefinedElsewhere(root.name.value())` for each setting to get the score for this setting and create the score table of the factor.
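The last step can be sketched as follows. This is a simplified illustration under stated assumptions, not the generated code: `NameStructure` is a hypothetical stand-in for the structure behind `root.name` (without any MPGraph node), the body of `atomicScoringFunctionDefinedElsewhere` is made up for the example, and the table is a plain array with one entry per setting.

```scala
// Hypothetical stand-in for the generated structure of `person.name` (no MPGraph involved).
final class NameStructure(names: Seq[String]) {
  private val dom = names.toVector
  private var setting = -1
  def value(): String = dom(setting)
  def resetSetting(): Unit = { setting = -1 }
  def hasNextSetting: Boolean = setting < dom.length - 1
  def nextSetting(): Unit = { setting += 1 }
}

object ScoreTableSketch {
  // Made-up body for illustration; in the text this function is opaque to the macro.
  def atomicScoringFunctionDefinedElsewhere(name: String): Double =
    if (name == "Vivek") 1.0 else 0.0

  // One table entry per setting, in the order in which the settings are visited.
  def scoreTable(name: NameStructure): Array[Double] = {
    val table = Array.newBuilder[Double]
    name.resetSetting()
    while (name.hasNextSetting) {
      name.nextSetting()
      table += atomicScoringFunctionDefinedElsewhere(name.value())
    }
    table.result()
  }

  def main(args: Array[String]): Unit = {
    val table = scoreTable(new NameStructure(Seq("Vivek", "Sameer")))
    println(table.mkString(", ")) // 1.0, 0.0
  }
}
```

Because the scoring function is treated as a black box, the table has to be filled by exhaustive enumeration, so its size grows with the number of joint settings of the structures involved.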