
Factor Graph Generation from Scala Code

Sebastian Riedel edited this page Mar 1, 2014 · 21 revisions

Conceptually, generating factor graphs happens in three stages:

  1. Converting the sample space into a set of nodes isomorphic to the sample space (for each joint state of all nodes there is exactly one element in the sample space, and vice versa)
  2. Converting the predicate/filter into assignments and domain restrictions for nodes
  3. Converting the objective into a set of factors

Sample Space to Nodes

Each sample space expression can be viewed as a tree of sub-spaces. For each node (sub-space) in this tree we generate a class that extends Structure. Objects of this class store all nodes needed to represent values in the sub-space, recursively using other Structure instances where necessary.

/**
 * A structure is a collection of MPGraph nodes whose assignments correspond to values of type `T`.
 * @tparam T the type of the values this structure can generate.
 */
trait Structure[+T] {
  /**
   * @return all nodes in this structure (including nodes of substructures).
   */
  def nodes(): Iterator[MPGraph.Node]
  /**
   * @return the value that the current assignment to all nodes represents.
   */
  def value(): T
  /**
   * Sets all nodes to their argmax belief. todo: this could be implemented generically using nodes().
   */
  def setToArgmax(): Unit
  /**
   * Resets the state of all nodes.
   */
  def resetSetting(): Unit
  /**
   * @return true if the structure has a next state, false if we have iterated over all its states.
   */
  def hasNextSetting: Boolean
  /**
   * Sets the structure to its next state by changing the assignments of one or more of its nodes.
   */
  def nextSetting(): Unit
}

Once we have converted our sample space into this format, we can run inference on the MPGraph (assuming we have added factors accordingly), call setToArgmax and then return value() to the client.
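The setting-enumeration half of the trait (resetSetting, hasNextSetting, nextSetting, value) is enough for a client to visit every value of a sample space. The sketch below is a hedged illustration, not part of the generated code: `ToyStructure`, `SeqStructure` and `bruteForceArgmax` are hypothetical names, the structure holds no MPGraph nodes, and exhaustive enumeration stands in for running inference, which is only feasible for tiny spaces.

```scala
// Hypothetical, stripped-down mirror of the Structure trait: no MPGraph nodes,
// only the methods a client needs to step through settings.
trait ToyStructure[+T] {
  def value(): T
  def resetSetting(): Unit
  def hasNextSetting: Boolean
  def nextSetting(): Unit
}

// A toy structure over a fixed sequence; `setting` plays the role of node.setting
// and starts "before the first state" at -1, as in AtomicStructure.
final class SeqStructure[T](dom: IndexedSeq[T]) extends ToyStructure[T] {
  private var setting = -1
  def value(): T = dom(setting)
  def resetSetting(): Unit = setting = -1
  def hasNextSetting: Boolean = setting < dom.length - 1
  def nextSetting(): Unit = setting += 1
}

// Brute-force stand-in for "run inference, call setToArgmax, return value()":
// visit every setting and keep the highest-scoring value (first one wins ties).
def bruteForceArgmax[T](s: ToyStructure[T])(score: T => Double): T = {
  s.resetSetting()
  var best: Option[(T, Double)] = None
  while (s.hasNextSetting) {
    s.nextSetting()
    val v = s.value()
    val sc = score(v)
    if (best.forall(_._2 < sc)) best = Some((v, sc))
  }
  best.map(_._1).getOrElse(throw new NoSuchElementException("empty structure"))
}
```

For example, `bruteForceArgmax(new SeqStructure(Vector(1, 2, 3)))(x => -math.abs(x - 2).toDouble)` returns `2`.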

We are considering the following types of spaces.

Atomic Iterables

Consider a sample space

Seq(1,2,3)

Roughly speaking, this sample space is mapped to the following Structure class:

final class AtomicStructure extends Structure[Int] {
  // the values of the sample space, and a map from each value to its index
  val atomDom = Seq(1, 2, 3).toArray
  val atomIndex = atomDom.zipWithIndex.toMap
  // mpGraph is a variable storing an MPGraph somewhere in scope.
  val node = mpGraph.addNode(atomDom.length)
  // node.setting indexes into node.domain, which holds indices into atomDom
  private def updateValue(): Unit = node.value = node.domain(node.setting)
  def value(): Int = atomDom(node.value)
  def nodes(): Iterator[MPGraph.Node] = Iterator(node)
  def resetSetting(): Unit = node.setting = -1
  def hasNextSetting: Boolean = node.setting < node.dim - 1
  def nextSetting(): Unit = {
    node.setting += 1
    updateValue()
  }
  def setToArgmax(): Unit = {
    node.setting = MoreArrayOps.maxIndex(node.b)
    updateValue()
  }
  // This method is not in the trait for covariance reasons, but it is still essential and every
  // structure needs it. Calls to methods of the Structure interface should generally be
  // non-polymorphic anyway, so the trait is really just a guideline for what needs to be implemented.
  final def observe(value: Int): Unit = {
    val index = atomIndex(value)
    node.domain = Array(index)
    node.dim = 1
  }
}
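Because the class above depends on an MPGraph in scope, here is a self-contained sketch of the same logic that can be run on its own. It is an approximation under stated assumptions: `ToyNode` is a hypothetical stand-in for `MPGraph.Node` with only the fields the class touches, `maxBy` replaces `MoreArrayOps.maxIndex`, resetting the beliefs in `observe` is an assumption of this sketch, and `allValues` is an illustrative helper. It shows the indirection `setting -> domain -> atomDom`, and how `observe` collapses the node to a single state.

```scala
// Hypothetical stand-in for MPGraph.Node with only the fields AtomicStructure uses.
final class ToyNode(var dim: Int) {
  var domain: Array[Int] = Array.range(0, dim) // indices into the atom domain
  var setting: Int = -1                        // position within `domain`
  var value: Int = -1                          // cached domain(setting)
  var b: Array[Double] = Array.fill(dim)(0.0)  // one belief per domain entry
}

final class ToyAtomicStructure[T](atoms: Seq[T]) {
  private val atomDom: IndexedSeq[T] = atoms.toIndexedSeq
  private val atomIndex: Map[T, Int] = atomDom.zipWithIndex.toMap
  val node = new ToyNode(atomDom.length)

  private def updateValue(): Unit = node.value = node.domain(node.setting)
  def value(): T = atomDom(node.value)
  def resetSetting(): Unit = node.setting = -1
  def hasNextSetting: Boolean = node.setting < node.dim - 1
  def nextSetting(): Unit = { node.setting += 1; updateValue() }
  def setToArgmax(): Unit = {
    node.setting = node.b.indices.maxBy(i => node.b(i)) // stand-in for MoreArrayOps.maxIndex
    updateValue()
  }
  def observe(v: T): Unit = {
    node.domain = Array(atomIndex(v)) // the only remaining domain entry
    node.dim = 1
    node.b = Array(0.0)               // assumption: beliefs follow the restricted domain
  }
  // Collects every value reachable through resetSetting/hasNextSetting/nextSetting.
  def allValues(): List[T] = {
    resetSetting()
    val buf = List.newBuilder[T]
    while (hasNextSetting) { nextSetting(); buf += value() }
    buf.result()
  }
}
```

After `observe(3)` on a structure over `Seq(1, 2, 3)`, enumeration yields only `3`, which is exactly the domain restriction the predicate stage relies on.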

Case Class Sample Spaces

Consider a sample space

case class Person(name: String, age: Int)
val names = Seq("Vivek", "Sameer")
all(Person)(c(names, Range(0, 100)))

We generate the following Structure class:

final class PersonStructure extends Structure[Person] {
  final class AtomicStructure1 extends Structure[String] { /* see AtomicStructure above, over names */ }
  final class AtomicStructure2 extends Structure[Int] { /* see AtomicStructure above, over Range(0, 100) */ }
  val name = new AtomicStructure1()
  val age = new AtomicStructure2()
  def fields: Iterator[Structure[Any]] = Iterator(name, age)
  def value(): Person = new Person(name.value(), age.value())
  def nodes(): Iterator[MPGraph.Node] = fields.flatMap(_.nodes())
  // steps through the joint settings of all fields
  private var iterator: Iterator[Unit] = _
  def resetSetting(): Unit = iterator = Structure.settingsIterator(List(name, age).reverse)()
  def hasNextSetting: Boolean = iterator.hasNext
  def nextSetting(): Unit = iterator.next()
  def setToArgmax(): Unit = fields.foreach(_.setToArgmax())
  def observe(value: Person): Unit = {
    name.observe(value.name)
    age.observe(value.age)
  }
}
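The joint settings of a composite structure form the Cartesian product of its fields' settings, so `Structure.settingsIterator` has to behave like an odometer. Its real implementation is not shown on this page; the following is a hypothetical sketch of one way to do it. `Settable`, `Counter` and this `settingsIterator` are illustrative names, and the convention that the head of the list varies fastest is a choice of this sketch.

```scala
// Hypothetical minimal interface: just the setting-enumeration part of Structure.
trait Settable {
  def resetSetting(): Unit
  def hasNextSetting: Boolean
  def nextSetting(): Unit
}

// A toy leaf with `dim` states, enumerated like AtomicStructure (setting starts at -1).
final class Counter(val dim: Int) extends Settable {
  var setting: Int = -1
  def resetSetting(): Unit = setting = -1
  def hasNextSetting: Boolean = setting < dim - 1
  def nextSetting(): Unit = setting += 1
}

// Odometer over the children: each next() moves the parent to its next joint setting.
// In this sketch the head of `children` varies fastest.
def settingsIterator(children: List[Settable]): Iterator[Unit] = new Iterator[Unit] {
  children.foreach(_.resetSetting())
  private var started = false

  def hasNext: Boolean =
    if (!started) children.forall(_.hasNextSetting) // every child needs at least one state
    else children.exists(_.hasNextSetting)

  def next(): Unit = {
    if (!hasNext) throw new NoSuchElementException("no more settings")
    if (!started) {
      children.foreach(_.nextSetting()) // move every child to its first state
      started = true
    } else {
      // Children already at their last state wrap around to their first state...
      val (exhausted, rest) = children.span(c => !c.hasNextSetting)
      exhausted.foreach { c => c.resetSetting(); c.nextSetting() }
      // ...and the first child that can still advance does so.
      rest.head.nextSetting()
    }
  }
}
```

With two children of 2 and 3 states the iterator yields 6 joint settings, matching the size of the product space.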

Maps / Functions

Implemented, but documentation is still missing.

Predicates to Assignments and Domain Restrictions

Let's assume we have the case class sample space from above and the following argmax macro call:

argmax(sampleSpace)(p => p.name == "Vivek")(obj)

We incorporate this predicate into the factor graph setup as follows. Most important for this step (and for the objective step later) are structure matchers: functions that take a tree and return the tree of the corresponding structure sub-object, if there is one. With such a matcher and the observe methods of structure objects (defined in the implementations, not currently in the trait), incorporating this predicate is easy:

def matchStructure(tree:Tree):Option[Tree] = ???
predicate match {
  case q"$x == $value" => matchStructure(x) match {
    case Some(structure) => addToCode(q"$structure.observe($value)") 
    case _ => //return the predicate as is and just turn it into a deterministic factor
  }
}

(This is simplified code; the actual code looks a little different.) Many todos remain here. Most notably, the current code cannot even handle conjunctions yet, although this should be easy to add.
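To illustrate how conjunctions could be handled, here is a hedged sketch that works on a hypothetical toy predicate AST rather than on real Scala trees and quasiquotes (which need a macro context). `IsEqual`, `Conj`, `Opaque`, `render` and `compilePredicate` are illustrative names; structure matching is modeled as a lookup from a field path to the name of a structure sub-object. Equalities on matched structures become `observe(...)` calls, conjunctions are split recursively, and everything else is returned unchanged, to be turned into deterministic factors.

```scala
// Hypothetical toy predicate AST standing in for the Scala trees the macro sees.
sealed trait Pred
final case class IsEqual(path: List[String], value: Any) extends Pred // e.g. p.name == "Vivek"
final case class Conj(left: Pred, right: Pred) extends Pred           // e.g. a && b
final case class Opaque(code: String) extends Pred                    // anything else

// Renders a constant as Scala source code.
def render(v: Any): String = v match {
  case s: String => "\"" + s + "\""
  case other => other.toString
}

// Splits a predicate into observe(...) calls for matched equalities, plus the
// leftover predicates that would become deterministic factors.
def compilePredicate(pred: Pred, matchStructure: List[String] => Option[String]): (List[String], List[Pred]) =
  pred match {
    case IsEqual(path, v) =>
      matchStructure(path) match {
        case Some(structure) => (List(s"$structure.observe(${render(v)})"), Nil)
        case None => (Nil, List(pred))
      }
    case Conj(l, r) =>
      val (obsL, restL) = compilePredicate(l, matchStructure)
      val (obsR, restR) = compilePredicate(r, matchStructure)
      (obsL ++ obsR, restL ++ restR)
    case other => (Nil, List(other))
  }
```

For `p.name == "Vivek" && p.age == 30` with matches for both fields, this yields the two calls `root.name.observe("Vivek")` and `root.age.observe(30)` and no leftover predicate.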

Objective to Factors

Like the mapping from sample space to nodes, the mapping from objective to factors traverses the objective tree. However, we currently do not generate a dedicated class for each node in the objective tree. Instead, each node gets one Scala expression that generates its factors (without keeping them in an object).
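To make the "one expression per node" idea concrete, here is a small sketch over a hypothetical objective AST in which the only compound node is a sum (an assumption; the real traversal works on Scala trees inside the macro). Each case of the match plays the role of the per-node expression: a sum contributes the factors of its children, and an atomic leaf contributes one factor. `Objective`, `Sum`, `Atomic`, `FactorSpec` and `factorsFor` are illustrative names, not part of the generated code.

```scala
// Hypothetical objective AST: sums of atomic terms, each term touching some structures.
sealed trait Objective
final case class Sum(terms: List[Objective]) extends Objective
final case class Atomic(name: String, structures: List[String]) extends Objective

// Hypothetical description of one factor to add to the graph.
final case class FactorSpec(name: String, structures: List[String])

// One expression per objective node, no per-node class: sums contribute the factors
// of their children, atomic leaves contribute a single factor.
def factorsFor(obj: Objective): List[FactorSpec] = obj match {
  case Sum(terms) => terms.flatMap(factorsFor)
  case Atomic(name, structures) => List(FactorSpec(name, structures))
}
```

Because nothing is stored per objective node, the result is simply a flat list of factors, however deeply the sums are nested.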

Atomic / Leaf Objectives

The simplest conversion case arises when the objective (or, during traversal, a sub-objective) is atomic, i.e., a leaf that cannot be broken down further. Assume again the Person sample space from above and the following right-hand side of the objective:

atomicScoringFunctionDefinedElsewhere(person.name)

where def atomicScoringFunctionDefinedElsewhere(name: String): Double is a function whose definition we cannot access (or one we have marked with a special Atomic annotation, yet to be implemented).

We proceed in the following steps:

  1. Use a structure matcher to find all structures involved in the objective (at the finest possible granularity). Assuming the root (case class) structure is called root, we would get the list of structures List(root.name).

  2. Use a structure matcher to replace expressions that refer to data fields with expressions that refer to values of the generated Structure object. For example, with the root structure called root, person.name is replaced by root.name.value(), so the objective becomes atomicScoringFunctionDefinedElsewhere(root.name.value()).

  3. Iterate over all settings of the involved structures, call atomicScoringFunctionDefinedElsewhere(root.name.value()) for each setting to get its score, and fill the factor's score table with these scores.
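Step 3 amounts to evaluating the atomic function once per joint setting. The sketch below is a simplification under stated assumptions: instead of stepping real Structure objects and calling value(), it enumerates the involved structures' domains directly; `scoreTable` is a hypothetical helper; the table layout (last structure varies fastest) is an arbitrary choice; and the body of `atomicScoringFunctionDefinedElsewhere` used in the example is made up, since the real one is opaque by definition. Values are typed as `Any` because the involved structures can have different value types.

```scala
// Hypothetical helper: build a dense score table for an atomic objective by
// evaluating `score` on every joint setting of the involved structures.
// Entry order: the last structure in `domains` varies fastest.
def scoreTable(domains: List[IndexedSeq[Any]])(score: List[Any] => Double): Array[Double] = {
  // All joint settings, built right to left so that earlier domains form the outer loops.
  val settings: List[List[Any]] = domains.foldRight(List(List.empty[Any])) { (dom, acc) =>
    for (v <- dom.toList; rest <- acc) yield v :: rest
  }
  settings.map(score).toArray
}
```

For the single structure root.name over `Vector("Vivek", "Sameer")` and a stand-in scoring function that returns 1.0 for "Vivek" and 0.0 otherwise, the table is `Array(1.0, 0.0)`.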