diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e28f11a0b9..c57d5c6e3c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -191,8 +191,8 @@ Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) = Scale(W.diag, bias, σ) function _size_check(layer, x::AbstractArray, (d, n)::Pair) - d > 0 || throw(DimensionMismatch(string("layer ", layer, - " expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x)))) + d <= ndims(x) || throw(DimensionMismatch(string("layer ", layer, + " expects ndims(input) >= ", d, ", but got ", summary(x)))) size(x, d) == n || throw(DimensionMismatch(string("layer ", layer, lazy" expects size(input, $d) == $n, but got ", summary(x)))) end