
Writing Model


In this section, we take DeepWalk and GraphSage as examples to show how to implement a graph embedding learning model using tf_euler and TensorFlow.

First, import the required libraries:

import tensorflow as tf
import tf_euler

Graph embedding learning

Following the encoder/decoder paradigm (arXiv:1709.05584), a graph embedding learning model can generally be divided into three steps:

  1. Generate samples from the node set;
  2. Encode the samples into embedding vectors;
  3. Decode the loss and evaluation metrics from the embedding vectors.

Next, we write the graph embedding models following this three-step paradigm.

DeepWalk

Here we use tf_euler.layers as building blocks (you can also use tf.keras or another deep learning toolkit) to implement the graph embedding learning model as a tf_euler.layers.Layer. Its call function receives a tf.Tensor of input nodes and outputs a 4-tuple: the embeddings of the input nodes, the loss of the current mini-batch, the evaluation metric name of the model, and the evaluation score of the current mini-batch.
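
For example, a model following this interface could look like the hypothetical sketch below. ToyModel is not part of tf_euler; its loss and metric are placeholders, and only the call signature matters here:

class ToyModel(tf_euler.layers.Layer):
  def __init__(self, max_id, dim):
    super(ToyModel, self).__init__()
    self.encoder = tf_euler.layers.Embedding(max_id + 1, dim)

  def call(self, inputs):
    # 1. sample: here the input nodes themselves serve as the samples
    samples = inputs
    # 2. encode the samples into embedding vectors
    embedding = self.encoder(samples)
    # 3. decode a loss and an evaluation metric from the embeddings
    loss = tf.reduce_mean(tf.square(embedding))  # placeholder loss
    metric = tf.constant(0.0)                    # placeholder metric
    return (embedding, loss, 'metric', metric)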

Putting DeepWalk into the above paradigm, its implementation can be divided into the following three steps:

  1. Generate positive nodes by random walks from the source nodes, and sample negative nodes;
  2. Embed the source, positive, and negative nodes into vectors;
  3. Compute the cross-entropy loss and the MRR metric from the embedding vectors of the source, positive, and negative nodes.

class DeepWalk(tf_euler.layers.Layer):
  def __init__(self, node_type, edge_type, max_id, dim,
               num_negs=8, walk_len=3, left_win_size=1, right_win_size=1):
    super(DeepWalk, self).__init__()
    self.node_type = node_type
    self.edge_type = edge_type
    self.max_id = max_id
    self.num_negs = num_negs
    self.walk_len = walk_len
    self.left_win_size = left_win_size
    self.right_win_size = right_win_size

    self.target_encoder = tf_euler.layers.Embedding(max_id + 1, dim)
    self.context_encoder = tf_euler.layers.Embedding(max_id + 1, dim)

  def call(self, inputs):
    src, pos, negs = self.sampler(inputs)
    embedding = self.target_encoder(src)
    embedding_pos = self.context_encoder(pos)
    embedding_negs = self.context_encoder(negs)
    loss, mrr = self.decoder(embedding, embedding_pos, embedding_negs)
    # Return the target embeddings of the original input nodes, not of the
    # expanded <source, positive> pairs used for the loss.
    embedding = self.target_encoder(inputs)
    return (embedding, loss, 'mrr', mrr)

  def sampler(self, inputs):
    batch_size = tf.size(inputs)
    # Walk walk_len steps from each input node along the configured edge type;
    # path has shape [batch_size, walk_len + 1].
    path = tf_euler.random_walk(
        inputs, [self.edge_type] * self.walk_len,
        default_node=self.max_id + 1)
    # Generate <source, context> pairs within the skip window;
    # pair has shape [batch_size, num_pairs, 2].
    pair = tf_euler.gen_pair(path, self.left_win_size, self.right_win_size)
    num_pairs = pair.shape[1]
    src, pos = tf.split(pair, [1, 1], axis=-1)
    # Sample num_negs negative nodes of the configured node type for each pair.
    negs = tf_euler.sample_node(batch_size * num_pairs * self.num_negs,
                                self.node_type)
    src = tf.reshape(src, [batch_size * num_pairs, 1])
    pos = tf.reshape(pos, [batch_size * num_pairs, 1])
    negs = tf.reshape(negs, [batch_size * num_pairs, self.num_negs])
    return src, pos, negs

  def decoder(self, embedding, embedding_pos, embedding_negs):
    # embedding / embedding_pos: [B, 1, dim]; embedding_negs: [B, num_negs, dim],
    # where B = batch_size * num_pairs.
    logits = tf.matmul(embedding, embedding_pos, transpose_b=True)
    neg_logits = tf.matmul(embedding, embedding_negs, transpose_b=True)
    # Positive pairs are labeled 1, negative pairs 0.
    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits), logits=logits)
    negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_logits), logits=neg_logits)
    loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent)
    mrr = tf_euler.metrics.mrr_score(logits, neg_logits)
    return loss, mrr

tf_euler provides a series of operators (Euler-OPs) to access the Euler graph engine from within the TensorFlow computation graph. Here we use tf_euler.random_walk to generate a path along the configured edge type, tf_euler.gen_pair to generate <source node, positive node> pairs from the path, and tf_euler.sample_node to sample negative nodes of the configured node type.
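
These Euler-OPs can also be used on their own. The snippet below is a minimal sketch that assumes the ppi graph used in the training script below has already been initialized (56944 is its max node id, and [0, 1] are its edge types):

nodes = tf_euler.sample_node(4, tf_euler.ALL_NODE_TYPE)  # [4] random node ids
path = tf_euler.random_walk(nodes, [[0, 1]] * 3,
                            default_node=56944 + 1)      # [4, 4]: each source node plus 3 steps
pair = tf_euler.gen_pair(path, 1, 1)                     # [4, num_pairs, 2] <source, context> pairs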

We can use tf_euler.sample_node to sample nodes from the whole graph as mini-batches for training:

tf_euler.initialize_embedded_graph('ppi')  # graph data directory
source = tf_euler.sample_node(128, tf_euler.ALL_NODE_TYPE)
source.set_shape([128])

model = DeepWalk(tf_euler.ALL_NODE_TYPE, [0, 1], 56944, 256)
_, loss, metric_name, metric = model(source)

global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss, global_step)

tf.logging.set_verbosity(tf.logging.INFO)
with tf.train.MonitoredTrainingSession(
  hooks=[
      tf.train.LoggingTensorHook({'step': global_step,
                                  'loss': loss, metric_name: metric}, 100),
      tf.train.StopAtStepHook(2000)
  ]) as sess:
  while not sess.should_stop():
    sess.run(train_op)

Running the above code produces output like the following:

INFO:tensorflow:loss = 4804.9565, mrr = 0.3264798, step = 1
INFO:tensorflow:loss = 4770.668, mrr = 0.39584208, step = 101 (0.765 sec)
INFO:tensorflow:loss = 4713.837, mrr = 0.37533116, step = 201 (0.676 sec)
INFO:tensorflow:loss = 4120.8774, mrr = 0.42687973, step = 301 (0.653 sec)
INFO:tensorflow:loss = 3288.204, mrr = 0.439512, step = 401 (0.674 sec)
INFO:tensorflow:loss = 2826.6309, mrr = 0.46083882, step = 501 (0.662 sec)
INFO:tensorflow:loss = 2562.7861, mrr = 0.5067806, step = 601 (0.656 sec)
INFO:tensorflow:loss = 2336.0562, mrr = 0.55503833, step = 701 (0.670 sec)
INFO:tensorflow:loss = 2101.0967, mrr = 0.6194568, step = 801 (0.664 sec)
INFO:tensorflow:loss = 1984.6118, mrr = 0.65155166, step = 901 (0.647 sec)
INFO:tensorflow:loss = 1855.1826, mrr = 0.6955864, step = 1001 (0.621 sec)
INFO:tensorflow:loss = 1680.2745, mrr = 0.74010307, step = 1101 (0.648 sec)
INFO:tensorflow:loss = 1525.5436, mrr = 0.7830129, step = 1201 (0.628 sec)
INFO:tensorflow:loss = 1325.8943, mrr = 0.84210175, step = 1301 (0.672 sec)
INFO:tensorflow:loss = 1274.5737, mrr = 0.85022587, step = 1401 (0.689 sec)
INFO:tensorflow:loss = 1153.6146, mrr = 0.8824446, step = 1501 (0.645 sec)
INFO:tensorflow:loss = 1144.9847, mrr = 0.88094825, step = 1601 (0.645 sec)
INFO:tensorflow:loss = 961.09924, mrr = 0.92628604, step = 1701 (0.616 sec)
INFO:tensorflow:loss = 940.64496, mrr = 0.91833764, step = 1801 (0.634 sec)
INFO:tensorflow:loss = 888.75397, mrr = 0.946753, step = 1901 (0.656 sec)

GraphSage

GraphSage is an improved version of GCN that can be used for supervised learning on labeled graphs. GraphSage samples the neighbors of a node and aggregates their features to obtain the node's embedding vector. Putting supervised GraphSage into the above paradigm, its implementation can be divided into the following three steps:

  1. Get the labels of the nodes;
  2. Perform multi-hop neighbor sampling for each node, take the node's features/attributes as the original embedding vectors, and apply multi-layer aggregation to them to obtain the final embedding vectors;
  3. Linearly classify the final embedding vectors of the nodes to get the sigmoid cross-entropy loss and the F1 score.

Note that in each layer, GraphSage aggregates the intermediate embedding vectors of each node and its neighbors to generate the next layer's embedding vectors. Here is an example using a mean aggregator:

class MeanAggregator(tf_euler.layers.Layer):
  def __init__(self, dim, activation=tf.nn.relu):
    super(MeanAggregator, self).__init__()
    self.self_layer = tf_euler.layers.Dense(
        dim // 2, activation=activation, use_bias=False)
    self.neigh_layer = tf_euler.layers.Dense(
        dim // 2, activation=activation, use_bias=False)

  def call(self, inputs):
    self_embedding, neigh_embedding = inputs
    agg_embedding = tf.reduce_mean(neigh_embedding, axis=1)
    from_self = self.self_layer(self_embedding)
    from_neighs = self.neigh_layer(agg_embedding)
    return tf.concat([from_self, from_neighs], 1)
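
As a quick shape check (the sizes below are illustrative assumptions, not values from the examples), the aggregator maps a batch of node embeddings and their sampled neighbor embeddings to dim-dimensional outputs:

agg = MeanAggregator(dim=8)
self_emb = tf.zeros([32, 16])       # [batch, feature_dim]
neigh_emb = tf.zeros([32, 10, 16])  # [batch, num_neighbors, feature_dim]
out = agg((self_emb, neigh_emb))    # [32, 8]: concat of the self part and the neighbor part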

We use tf_euler.sample_fanout to perform multi-hop neighbor sampling, then use tf_euler.get_dense_feature to get the features of the nodes in each hop, and iteratively call the MeanAggregator defined above for aggregation:

class SageEncoder(tf_euler.layers.Layer):
  def __init__(self, metapath, fanouts, dim, feature_idx, feature_dim):
    super(SageEncoder, self).__init__()
    self.metapath = metapath
    self.fanouts = fanouts
    self.num_layers = len(metapath)

    self.feature_idx = feature_idx
    self.feature_dim = feature_dim

    self.aggregators = []
    for layer in range(self.num_layers):
      activation = tf.nn.relu if layer < self.num_layers - 1 else None
      self.aggregators.append(MeanAggregator(dim, activation=activation))
    self.dims = [feature_dim] + [dim] * self.num_layers

  def call(self, inputs):
    # Sample the multi-hop neighborhood; samples[i] holds the nodes of hop i
    # (samples[0] is the input nodes themselves).
    samples = tf_euler.sample_fanout(inputs, self.metapath, self.fanouts)[0]
    # Use each hop's dense node feature as its initial embedding.
    hidden = [
        tf_euler.get_dense_feature(sample,
                                   [self.feature_idx], [self.feature_dim])[0]
        for sample in samples]
    # Aggregate layer by layer; each layer consumes one more hop of neighbors.
    for layer in range(self.num_layers):
      aggregator = self.aggregators[layer]
      next_hidden = []
      for hop in range(self.num_layers - layer):
        neigh_shape = [-1, self.fanouts[hop], self.dims[layer]]
        h = aggregator((hidden[hop], tf.reshape(hidden[hop + 1], neigh_shape)))
        next_hidden.append(h)
      hidden = next_hidden
    return hidden[0]
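
For instance, the encoder can be used on its own. The constants below mirror the ppi training script at the end of this section (edge type 0 for both hops, 10 neighbors per hop, feature 1 of dimension 50):

encoder = SageEncoder(metapath=[[0], [0]], fanouts=[10, 10], dim=256,
                      feature_idx=1, feature_dim=50)
nodes = tf_euler.sample_node(512, 0)
nodes.set_shape([512])
node_embedding = encoder(nodes)  # [512, 256]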

Finally, we use tf_euler.get_dense_feature to fetch the labels of the nodes from the graph and linearly classify the nodes using the final-layer embedding vectors:

class GraphSage(tf_euler.layers.Layer):
  def __init__(self, label_idx, label_dim,
               metapath, fanouts, dim, feature_idx, feature_dim):
    super(GraphSage, self).__init__()
    self.label_idx = label_idx
    self.label_dim = label_dim
    self.encoder = SageEncoder(metapath, fanouts, dim, feature_idx, feature_dim)
    self.predict_layer = tf_euler.layers.Dense(label_dim)

  def call(self, inputs):
    nodes, labels = self.sampler(inputs)
    embedding = self.encoder(nodes)
    loss, f1 = self.decoder(embedding, labels)
    return (embedding, loss, 'f1', f1)

  def sampler(self, inputs):
    labels = tf_euler.get_dense_feature(inputs, [self.label_idx],
                                                [self.label_dim])[0]
    return inputs, labels

  def decoder(self, embedding, labels):
    logits = self.predict_layer(embedding)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    # Threshold the per-label probabilities at 0.5 to get binary predictions.
    predictions = tf.floor(tf.nn.sigmoid(logits) + 0.5)
    f1 = tf_euler.metrics.f1_score(labels, predictions)
    return tf.reduce_mean(loss), f1

We can train the model in the same way as in the DeepWalk section above:

tf_euler.initialize_embedded_graph('ppi')
source = tf_euler.sample_node(512, 0)
source.set_shape([512])

model = GraphSage(0, 121, [[0], [0]], [10, 10], 256, 1, 50)
_, loss, metric_name, metric = model(source)

global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer(0.01).minimize(loss, global_step)

tf.logging.set_verbosity(tf.logging.INFO)
with tf.train.MonitoredTrainingSession(
  hooks=[
      tf.train.LoggingTensorHook({'step': global_step,
                                  'loss': loss, metric_name: metric}, 100),
      tf.train.StopAtStepHook(2000)
  ]) as sess:
  while not sess.should_stop():
    sess.run(train_op)

Running the above code produces output like the following:

INFO:tensorflow:f1 = 0.3850271, loss = 0.69317585, step = 1
INFO:tensorflow:f1 = 0.42160043, loss = 0.5167424, step = 101 (4.987 sec)
INFO:tensorflow:f1 = 0.4489097, loss = 0.5023754, step = 201 (4.788 sec)
INFO:tensorflow:f1 = 0.4701608, loss = 0.49763823, step = 301 (4.866 sec)
INFO:tensorflow:f1 = 0.4902702, loss = 0.48410782, step = 401 (4.809 sec)
INFO:tensorflow:f1 = 0.5044798, loss = 0.4730545, step = 501 (4.851 sec)
INFO:tensorflow:f1 = 0.5104125, loss = 0.4705497, step = 601 (4.866 sec)
INFO:tensorflow:f1 = 0.51712954, loss = 0.47582737, step = 701 (4.844 sec)
INFO:tensorflow:f1 = 0.5240817, loss = 0.46666723, step = 801 (4.871 sec)
INFO:tensorflow:f1 = 0.53172356, loss = 0.45738563, step = 901 (4.837 sec)
INFO:tensorflow:f1 = 0.53270173, loss = 0.4746988, step = 1001 (4.802 sec)
INFO:tensorflow:f1 = 0.53611106, loss = 0.46039847, step = 1101 (4.882 sec)
INFO:tensorflow:f1 = 0.5402253, loss = 0.46644467, step = 1201 (4.808 sec)
INFO:tensorflow:f1 = 0.5420937, loss = 0.47356603, step = 1301 (4.820 sec)
INFO:tensorflow:f1 = 0.5462865, loss = 0.45834514, step = 1401 (4.872 sec)
INFO:tensorflow:f1 = 0.5511238, loss = 0.45826617, step = 1501 (4.848 sec)
INFO:tensorflow:f1 = 0.5543519, loss = 0.4414709, step = 1601 (4.865 sec)
INFO:tensorflow:f1 = 0.5557352, loss = 0.4589582, step = 1701 (4.836 sec)
INFO:tensorflow:f1 = 0.5591235, loss = 0.45354822, step = 1801 (4.869 sec)
INFO:tensorflow:f1 = 0.56102884, loss = 0.44353116, step = 1901 (4.885 sec)