Skip to content

Commit

Permalink
vectorize calls to log densities in examples; #73
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed May 14, 2016
1 parent 5086e49 commit cedf2ed
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 63 deletions.
2 changes: 1 addition & 1 deletion examples/beta_bernoulli_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self):

def log_prob(self, xs, zs):
log_prior = beta.logpdf(zs, a=1.0, b=1.0)
log_lik = tf.pack([tf.reduce_sum(bernoulli.logpmf(xs, z)) \
log_lik = tf.pack([tf.reduce_sum(bernoulli.logpmf(xs, z))
for z in tf.unpack(zs)])
return log_lik + log_prior

Expand Down
2 changes: 1 addition & 1 deletion examples/beta_bernoulli_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self):

def log_prob(self, xs, zs):
log_prior = beta.logpdf(zs, a=1.0, b=1.0)
log_lik = tf.pack([tf.reduce_sum(bernoulli.logpmf(xs, z)) \
log_lik = tf.pack([tf.reduce_sum(bernoulli.logpmf(xs, z))
for z in tf.unpack(zs)])
return log_lik + log_prior

Expand Down
2 changes: 1 addition & 1 deletion examples/convolutional_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tensorflow as tf

from convolutional_vae_util import deconv2d
from edward import Variational, Normal
from edward.variationals import Variational, Normal
from progressbar import ETA, Bar, Percentage, ProgressBar
from scipy.misc import imsave
from tensorflow.examples.tutorials.mnist import input_data
Expand Down
44 changes: 20 additions & 24 deletions examples/mixture_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,38 +57,34 @@ def __init__(self, K, D):
self.c = 10
self.alpha = tf.ones([K])

def unpack_params(self, z):
"""Unpack parameters from a flattened vector."""
pi = z[0:self.K]
mus = z[self.K:(self.K+self.K*self.D)]
sigmas = z[(self.K+self.K*self.D):(self.K+2*self.K*self.D)]
def unpack_params(self, zs):
"""Unpack sets of parameters from a flattened matrix."""
pi = zs[:, 0:self.K]
mus = zs[:, self.K:(self.K+self.K*self.D)]
sigmas = zs[:, (self.K+self.K*self.D):(self.K+2*self.K*self.D)]
return pi, mus, sigmas

def log_prob(self, xs, zs):
"""Returns a vector [log p(xs, zs[1,:]), ..., log p(xs, zs[S,:])]."""
N = get_dims(xs)[0]
# Loop over each mini-batch zs[b,:]
log_prob = []
for z in tf.unpack(zs):
pi, mus, sigmas = self.unpack_params(z)
log_prior = dirichlet.logpdf(pi, self.alpha)
pi, mus, sigmas = self.unpack_params(zs)
log_prior = dirichlet.logpdf(pi, self.alpha)
log_prior += tf.reduce_sum(norm.logpdf(mus, 0, np.sqrt(self.c)), 1)
log_prior += tf.reduce_sum(invgamma.logpdf(sigmas, self.a, self.b), 1)

log_lik = []
n_minibatch = get_dims(zs)[0]
for s in xrange(n_minibatch):
log_lik_z = N*tf.reduce_sum(tf.log(pi), 1)
for k in xrange(self.K):
log_prior += norm.logpdf(mus[k*self.D], 0, np.sqrt(self.c))
log_prior += norm.logpdf(mus[k*self.D+1], 0, np.sqrt(self.c))
log_prior += invgamma.logpdf(sigmas[k*self.D], self.a, self.b)
log_prior += invgamma.logpdf(sigmas[k*self.D+1], self.a, self.b)
log_lik_z += tf.reduce_sum(multivariate_normal.logpdf(xs,
mus[s, (k*self.D):((k+1)*self.D)],
sigmas[s, (k*self.D):((k+1)*self.D)]))

log_lik = tf.constant(0.0, dtype=tf.float32)
for x in tf.unpack(xs):
for k in xrange(self.K):
log_lik += tf.log(pi[k])
log_lik += multivariate_normal.logpdf(x,
mus[(k*self.D):((k+1)*self.D)],
sigmas[(k*self.D):((k+1)*self.D)])
log_lik += [log_lik_z]

log_prob += [log_prior + log_lik]

return tf.pack(log_prob)
return log_prior + tf.pack(log_lik)

ed.set_seed(42)
x = np.loadtxt('data/mixture_data.txt', dtype='float32', delimiter=',')
Expand All @@ -101,4 +97,4 @@ def log_prob(self, xs, zs):
variational.add(InvGamma(model.K*model.D))

inference = ed.MFVI(model, variational, data)
inference.run(n_iter=10000, n_minibatch=5, n_data=5)
inference.run(n_iter=500, n_minibatch=5, n_data=5)
52 changes: 24 additions & 28 deletions examples/mixture_gaussian_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,43 +56,39 @@ def __init__(self, K, D):
self.c = 10
self.alpha = tf.ones([K])

def unpack_params(self, z):
"""Unpack parameters from a flattened vector."""
pi = z[0:self.K]
mus = z[self.K:(self.K+self.K*self.D)]
sigmas = z[(self.K+self.K*self.D):(self.K+2*self.K*self.D)]
def unpack_params(self, zs):
"""Unpack sets of parameters from a flattened matrix."""
pi = zs[:, 0:self.K]
mus = zs[:, self.K:(self.K+self.K*self.D)]
sigmas = zs[:, (self.K+self.K*self.D):(self.K+2*self.K*self.D)]
# Do the unconstrained to constrained transformation for MAP here.
pi = tf.sigmoid(pi)
pi = tf.concat(1, [pi[:, 0:(self.K-1)],
tf.expand_dims(1.0 - tf.reduce_sum(pi[:, 0:(self.K-1)], 1), 0)])
sigmas = tf.nn.softplus(sigmas)
return pi, mus, sigmas

def log_prob(self, xs, zs):
"""Returns a vector [log p(xs, zs[1,:]), ..., log p(xs, zs[S,:])]."""
N = get_dims(xs)[0]
# Loop over each mini-batch zs[b,:]
log_prob = []
for z in tf.unpack(zs):
# Do the unconstrained to constrained transformation for MAP here.
pi, mus, sigmas = self.unpack_params(z)
pi = tf.sigmoid(pi)
pi = tf.concat(0, [pi[0:(self.K-1)],
tf.expand_dims(1.0 - tf.reduce_sum(pi[0:(self.K-1)]), 0)])
sigmas = tf.nn.softplus(sigmas)
log_prior = dirichlet.logpdf(pi, self.alpha)
for k in xrange(self.K):
log_prior += norm.logpdf(mus[k*self.D], 0, np.sqrt(self.c))
log_prior += norm.logpdf(mus[k*self.D+1], 0, np.sqrt(self.c))
log_prior += invgamma.logpdf(sigmas[k*self.D], self.a, self.b)
log_prior += invgamma.logpdf(sigmas[k*self.D+1], self.a, self.b)
pi, mus, sigmas = self.unpack_params(zs)
log_prior = dirichlet.logpdf(pi, self.alpha)
log_prior += tf.reduce_sum(norm.logpdf(mus, 0, np.sqrt(self.c)))
log_prior += tf.reduce_sum(invgamma.logpdf(sigmas, self.a, self.b))

log_lik = tf.constant(0.0, dtype=tf.float32)
for x in tf.unpack(xs):
for k in xrange(self.K):
log_lik += tf.log(pi[k])
log_lik += multivariate_normal.logpdf(x,
mus[(k*self.D):((k+1)*self.D)],
sigmas[(k*self.D):((k+1)*self.D)])
log_lik = []
n_minibatch = get_dims(zs)[0]
for s in xrange(n_minibatch):
log_lik_z = N*tf.reduce_sum(tf.log(pi))
for k in xrange(self.K):
log_lik_z += tf.reduce_sum(multivariate_normal.logpdf(xs,
mus[s, (k*self.D):((k+1)*self.D)],
sigmas[s, (k*self.D):((k+1)*self.D)]))

log_prob += [log_prior + log_lik]
log_lik += [log_lik_z]

return tf.pack(log_prob)
return log_prior + tf.pack(log_lik)

ed.set_seed(42)
x = np.loadtxt('data/mixture_data.txt', dtype='float32', delimiter=',')
Expand Down
3 changes: 1 addition & 2 deletions examples/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def __init__(self, mu, std):
self.num_vars = 1

def log_prob(self, xs, zs):
return tf.pack([norm.logpdf(z, self.mu, self.std)
for z in tf.unpack(zs)])
return norm.logpdf(zs, self.mu, self.std)

ed.set_seed(42)
mu = tf.constant(1.0)
Expand Down
3 changes: 1 addition & 2 deletions examples/normal_idiomatic_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def __init__(self, mu, std):
self.num_vars = 1

def log_prob(self, xs, zs):
return tf.pack([norm.logpdf(z, self.mu, self.std)
for z in tf.unpack(zs)])
return norm.logpdf(zs, self.mu, self.std)

ed.set_seed(42)
mu = tf.constant(1.0)
Expand Down
3 changes: 1 addition & 2 deletions examples/normal_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def __init__(self, mu, Sigma):
self.num_vars = 1

def log_prob(self, xs, zs):
log_prior = tf.pack([norm.logpdf(z, mu, Sigma)
for z in tf.unpack(zs)])
log_prior = norm.logpdf(zs, mu, Sigma)
log_lik = tf.pack([tf.reduce_sum(norm.logpdf(xs, z, Sigma))
for z in tf.unpack(zs)])
return log_lik + log_prior
Expand Down
3 changes: 1 addition & 2 deletions examples/normal_two.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def __init__(self, mu, Sigma):
self.num_vars = get_dims(mu)[0]

def log_prob(self, xs, zs):
return tf.pack([multivariate_normal.logpdf(z, self.mu, self.Sigma)
for z in tf.unpack(zs)])
return multivariate_normal.logpdf(zs, self.mu, self.Sigma)

ed.set_seed(42)
mu = tf.constant([1.0, 1.0])
Expand Down

0 comments on commit cedf2ed

Please sign in to comment.