Skip to content

Commit

Permalink
add matrix inputs to uniform; add matrix input unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed May 14, 2016
1 parent 8db654d commit cc1ef0b
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 7 deletions.
14 changes: 7 additions & 7 deletions edward/stats/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def logpmf(self, x):
Arguments
---------
x: np.array or tf.Tensor
If univariate distribution, can be a scalar or vector.
If univariate distribution, can be a scalar, vector, or matrix.
If multivariate distribution, can be a vector or matrix.
params: np.array or tf.Tensor
Returns
-------
tf.Tensor
For univariate distributions, scalar if scalar input and
vector if vector input. For multivariate distributions,
scalar if vector input and vector if matrix input, where
the each element in the vector evaluates a row in the
matrix.
For univariate distributions, returns a scalar, vector, or
matrix corresponding to the size of input. For
multivariate distributions, returns a scalar if vector
input and vector if matrix input, where each element in
the vector evaluates a row in the matrix.
Note
----
Expand Down Expand Up @@ -357,7 +357,7 @@ def rvs(self, loc=0, scale=1, size=1):
def logpdf(self, x, loc=0, scale=1):
# Note there is no error checking if x is outside domain.
scale = tf.cast(tf.squeeze(scale), dtype=tf.float32)
return tf.squeeze(tf.ones([get_dims(x)[0]]) * -tf.log(scale))
return tf.squeeze(tf.ones(get_dims(x)) * -tf.log(scale))

bernoulli = Bernoulli()
beta = Beta()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_stats_bernoulli_logpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,11 @@ def test_logpmf_int_1d():
def test_logpmf_float_1d():
_test_logpmf([0.0, 1.0, 0.0], 0.5)
_test_logpmf([1.0, 0.0, 0.0], 0.75)

def test_logpmf_int_2d():
_test_logpmf(np.array([[0, 1, 0],[0, 1, 0]]), 0.5)
_test_logpmf(np.array([[1, 0, 0],[0, 1, 0]]), 0.75)

def test_logpmf_float_2d():
_test_logpmf(np.array([[0.0, 1.0, 0.0],[0.0, 1.0, 0.0]]), 0.5)
_test_logpmf(np.array([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0]]), 0.75)
3 changes: 3 additions & 0 deletions tests/test_stats_beta_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.5, 0.3, 0.8, 0.1], a=0.5, b=0.5)

def test_logpdf_2d():
_test_logpdf(np.array([[0.5, 0.3, 0.8, 0.1],[0.5, 0.3, 0.8, 0.1]]))
8 changes: 8 additions & 0 deletions tests/test_stats_binom_logpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ def test_logpmf_int_1d():
def test_logpmf_float_1d():
_test_logpmf([0.0, 1.0, 0.0], 1, 0.5)
_test_logpmf([1.0, 0.0, 0.0], 1, 0.75)

def test_logpmf_int_2d():
_test_logpmf(np.array([[0, 1, 0],[1, 0, 0]]), 1, 0.5)
_test_logpmf(np.array([[1, 0, 0],[0, 1, 0]]), 1, 0.75)

def test_logpmf_float_2d():
_test_logpmf(np.array([[0.0, 1.0, 0.0],[1.0, 0.0, 0.0]]), 1, 0.5)
_test_logpmf(np.array([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0]]), 1, 0.75)
3 changes: 3 additions & 0 deletions tests/test_stats_chi2_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.1, 1.0, 0.58, 2.3], df=3)

def test_logpdf_2d():
_test_logpdf(np.array([[0.1, 1.0, 0.58, 2.3],[0.3, 1.1, 0.68, 1.2]]), df=3)
3 changes: 3 additions & 0 deletions tests/test_stats_expon_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.5, 2.3, 5.8, 10.1], scale=5.0)

def test_logpdf_2d():
_test_logpdf(np.array([[0.5, 2.3, 5.8, 10.1],[0.5, 2.3, 5.8, 10.1]]), scale=5.0)
4 changes: 4 additions & 0 deletions tests/test_stats_gamma_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.5, 1.2, 5.3, 8.7], a=0.5, b=0.5)

def test_logpdf_2d():
_test_logpdf(np.array([[0.5, 1.2, 5.3, 8.7],[0.5, 1.2, 5.3, 8.7]]),
a=0.5, b=0.5)
8 changes: 8 additions & 0 deletions tests/test_stats_geom_logpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ def test_logpmf_int_1d():
def test_logpmf_float_1d():
_test_logpmf([1.0, 5.0, 3.0], 0.5)
_test_logpmf([2.0, 8.0, 2.0], 0.75)

def test_logpmf_int_2d():
_test_logpmf(np.array([[1, 5, 3],[2, 8, 2]]), 0.5)
_test_logpmf(np.array([[2, 8, 2],[1, 5, 3]]), 0.75)

def test_logpmf_float_2d():
_test_logpmf(np.array([[1.0, 5.0, 3.0],[2.0, 8.0, 2.0]]), 0.5)
_test_logpmf(np.array([[2.0, 8.0, 2.0],[1.0, 5.0, 3.0]]), 0.75)
4 changes: 4 additions & 0 deletions tests/test_stats_invgamma_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.5, 1.2, 5.3, 8.7], a=0.5, scale=0.5)

def test_logpdf_2d():
_test_logpdf(np.array([[0.5, 1.2, 5.3, 8.7],[0.5, 1.2, 5.3, 8.7]]),
a=0.5, scale=0.5)
3 changes: 3 additions & 0 deletions tests/test_stats_lognorm_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([2.0, 1.0, 0.58, 2.3])

def test_logpdf_2d():
_test_logpdf(np.array([[2.0, 1.0, 0.58, 2.3],[2.1, 1.3, 1.58, 0.3]]))
8 changes: 8 additions & 0 deletions tests/test_stats_nbinom_logpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,11 @@ def test_logpmf_int_1d():
def test_logpmf_float_1d():
_test_logpmf([1.0, 5.0, 3.0], 5, 0.5)
_test_logpmf([2.0, 8.0, 2.0], 5, 0.75)

def test_logpmf_int_2d():
_test_logpmf(np.array([[1, 5, 3],[2, 8, 2]]), 5, 0.5)
_test_logpmf(np.array([[2, 8, 2],[1, 5, 3]]), 5, 0.75)

def test_logpmf_float_2d():
_test_logpmf(np.array([[1.0, 5.0, 3.0],[2.0, 8.0, 2.0]]), 5, 0.5)
_test_logpmf(np.array([[2.0, 8.0, 2.0],[1.0, 5.0, 3.0]]), 5, 0.75)
3 changes: 3 additions & 0 deletions tests/test_stats_norm_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.0, 1.0, 0.58, 2.3])

def test_logpdf_2d():
_test_logpdf(np.array([[0.0, 1.0, 0.58, 2.3], [0.1, 1.5, 4.18, 0.3]]))
8 changes: 8 additions & 0 deletions tests/test_stats_poisson_logpmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ def test_logpmf_int_1d():
def test_logpmf_float_1d():
_test_logpmf([0.0, 1.0, 3.0], 0.5)
_test_logpmf([1.0, 8.0, 2.0], 0.75)

def test_logpmf_int_2d():
_test_logpmf(np.array([[0, 1, 3],[1, 8, 2]]), 0.5)
_test_logpmf(np.array([[1, 8, 2],[0, 1, 3]]), 0.75)

def test_logpmf_float_2d():
_test_logpmf(np.array([[0.0, 1.0, 3.0],[0.0, 1.0, 3.0]]), 0.5)
_test_logpmf(np.array([[1.0, 8.0, 2.0],[1.0, 8.0, 2.0]]), 0.75)
3 changes: 3 additions & 0 deletions tests/test_stats_t_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.0, 1.0, 0.58, 2.3], df=3)

def test_logpdf_2d():
_test_logpdf(np.array([[0.0, 1.0, 0.58, 2.3],[0.0, 1.0, 0.58, 2.3]]), df=3)
4 changes: 4 additions & 0 deletions tests/test_stats_truncnorm_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.0, 1.0, 0.58, 2.3], a=-1.0, b=3.0)

def test_logpdf_2d():
_test_logpdf(np.array([[0.0, 1.0, 0.58, 2.3],[0.0, 1.0, 0.58, 2.3]]),
a=-1.0, b=3.0)
4 changes: 4 additions & 0 deletions tests/test_stats_uniform_logpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ def test_logpdf_scalar():

def test_logpdf_1d():
_test_logpdf([0.5, 0.3, 0.8, 0.2], loc=0.1, scale=0.9)

def test_logpdf_2d():
_test_logpdf(np.array([[0.5, 0.3, 0.8, 0.2],[0.5, 0.3, 0.8, 0.2]]),
loc=0.1, scale=0.9)

0 comments on commit cc1ef0b

Please sign in to comment.