From f933bff22dca8e7054dbeaa074e35d36b338af41 Mon Sep 17 00:00:00 2001 From: Christian Gruber Date: Fri, 26 Jul 2024 17:13:44 +0200 Subject: [PATCH] Fix function _size_check() --- src/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e28f11a0b9..c57d5c6e3c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -191,8 +191,8 @@ Dense(W::LinearAlgebra.Diagonal, bias = true, σ = identity) = Scale(W.diag, bias, σ) function _size_check(layer, x::AbstractArray, (d, n)::Pair) - d > 0 || throw(DimensionMismatch(string("layer ", layer, - " expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x)))) + d <= ndims(x) || throw(DimensionMismatch(string("layer ", layer, + " expects ndims(input) >= ", d, ", but got ", summary(x)))) size(x, d) == n || throw(DimensionMismatch(string("layer ", layer, lazy" expects size(input, $d) == $n, but got ", summary(x)))) end