Skip to content

Commit

Permalink
Merge #787
Browse files Browse the repository at this point in the history
787: Fix up Euclidean distance at the origin r=willtebbutt a=willtebbutt

(Hopefully) fixes some Euclidean distance issues when the inputs are close together.

Fixes JuliaGaussianProcesses/KernelFunctions.jl#166

@molet could you try this out and see whether or not it fixes your issue?

I've introduced `FiniteDifferences` as a test dep because Zygote's own finite differencing isn't sufficiently accurate in this case. I'm not completely sure why.

Co-authored-by: wt <[email protected]>
Co-authored-by: willtebbutt <[email protected]>
  • Loading branch information
3 people authored Sep 11, 2020
2 parents b269ed3 + 664b05a commit ed366f3
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.5.6"
version = "0.5.7"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
26 changes: 15 additions & 11 deletions src/lib/distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ end
end

@adjoint function (::Euclidean)(x::AbstractVector, y::AbstractVector)
D = x.-y
D = x .- y
δ = sqrt(sum(abs2, D))
function euclidean::Real)
= (Δ / δ) .* D
= ifelse(iszero(δ), D, (Δ / δ) .* D)
return x̄, -
end
return δ, euclidean
Expand All @@ -59,26 +59,30 @@ end
@adjoint function colwise(s::Euclidean, x::AbstractMatrix, y::AbstractMatrix)
d = colwise(s, x, y)
return d, function::AbstractVector)
=./ d)' .* (x .- y)
=./ max.(d, eps(eltype(d))))' .* (x .- y)
return nothing, x̄, -
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D, back = pullback(
(X, Y) -> pairwise(SqEuclidean(), X, Y; dims = dims),
X,
Y,
)
D .= sqrt.(D)
return D, Δ -> (nothing, back./ (2 .* D))...)

# Modify the forwards-pass slightly to ensure stability on the reverse.
function _pairwise_euclidean(X, Y)
δ = eps(promote_type(eltype(X), eltype(Y)))^2
return sqrt.(max.(pairwise(SqEuclidean(), X, Y; dims=dims), δ))
end
D, back = pullback(_pairwise_euclidean, X, Y)

return D, function(Δ)
return (nothing, back(Δ)...)
end
end

@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2)
D, back = pullback(X -> pairwise(SqEuclidean(), X; dims = dims), X)
D .= sqrt.(D)
return D, function(Δ)
Δ = Δ ./ (2 .* D)
Δ = Δ ./ (2 .* max.(D, eps(eltype(D))))
Δ[diagind(Δ)] .= 0
return (nothing, first(back(Δ)))
end
Expand Down
5 changes: 3 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Expand Down
56 changes: 37 additions & 19 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv, batched_mul
using Base.Broadcast: broadcast_shape
using LoopVectorization: vmap
using Distributed: pmap
using FiniteDifferences

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
Expand All @@ -21,9 +22,11 @@ function ngradient(f, xs::AbstractArray...)
return grads
end

gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...),
gradient(f, xs...), rtol = 1e-5, atol = 1e-5))
function gradcheck(f, xs...)
grad_zygote = gradient(f, xs...)
grad_finite_difference = ngradient(f, xs...)
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5))
end

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
Expand Down Expand Up @@ -1059,41 +1062,55 @@ end
end

@testset "distances" begin
rng, P, Q, D = MersenneTwister(123456), 10, 9, 8
rng, P, Q, D = MersenneTwister(123456), 5, 4, 3

for (f, metric) in ((euclidean, Euclidean()), (sqeuclidean, SqEuclidean()))
let

@testset "scalar input" begin
x, y = randn(rng), randn(rng)
@test gradtest(x -> f(x[1], y), [x])
@test gradtest(x -> evaluate(metric, x[1], y), [x])
@test gradtest(y -> f(x, y[1]), [y])
@test gradtest(y -> evaluate(metric, x, y[1]), [y])
end

let
@testset "vector input" begin
x, y = randn(rng, D), randn(rng, D)
@test gradtest(x -> f(x, y), x)
@test gradtest(x -> evaluate(metric, x, y), x)
@test gradtest(y -> f(x, y), y)
@test gradtest(y -> evaluate(metric, x, y), y)
@test gradtest(x -> f(x, x), x)
end

# Check binary colwise.
let
@testset "binary colwise" begin
X, Y = randn(rng, D, P), randn(rng, D, P)
@test gradtest(X->colwise(metric, X, Y), X)
@test gradtest(Y->colwise(metric, X, Y), Y)
@test gradtest(X -> colwise(metric, X, Y), X)
@test gradtest(Y -> colwise(metric, X, Y), Y)
@test gradtest(X -> colwise(metric, X, X), X)
end

# Check binary pairwise.
let
@testset "binary pairwise" begin
X, Y = randn(rng, D, P), randn(rng, D, Q)
@test gradtest(X->pairwise(metric, X, Y; dims=2), X)
@test gradtest(Y->pairwise(metric, X, Y; dims=2), Y)
@test gradtest(X -> pairwise(metric, X, Y; dims=2), X)
@test gradtest(Y -> pairwise(metric, X, Y; dims=2), Y)

@testset "X == Y" begin
# Zygote's gradtest isn't sufficiently accurate to assess this, so we use
# FiniteDifferences.jl instead.
Y = copy(X)
Δ = randn(P, P)
Δ_fd = FiniteDifferences.j′vp(
central_fdm(5, 1), X -> pairwise(metric, X, Y; dims=2), Δ, X,
)
_, pb = Zygote.pullback(X -> pairwise(metric, X, Y; dims=2), X)

# This is impressively inaccurate, but at least it doesn't produce a NaN.
@test first(Δ_fd) first(pb(Δ)) atol=1e-3 rtol=1e-3
end
end

# Check binary pairwise when X and Y are close.
let
@testset "binary pairwise - X and Y close" begin
X = randn(rng, D, P)
Y = X .+ 1e-10
dist = pairwise(metric, X, Y; dims=2)
Expand All @@ -1106,9 +1123,10 @@ end
@test gradtest(Yt->pairwise(metric, Xt, Yt; dims=1), Yt)
end

# Check unary pairwise.
@test gradtest(X->pairwise(metric, X; dims=2), randn(rng, D, P))
@test gradtest(Xt->pairwise(metric, Xt; dims=1), randn(rng, P, D))
@testset "unary pairwise" begin
@test gradtest(X->pairwise(metric, X; dims=2), randn(rng, D, P))
@test gradtest(Xt->pairwise(metric, Xt; dims=1), randn(rng, P, D))
end
end
end

Expand Down

2 comments on commit ed366f3

@DhairyaLGandhi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21444

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.7 -m "<description of version>" ed366f32c0f520567526040d9f8acaf0d83613c3
git push origin v0.5.7

Please sign in to comment.