Skip to content

Commit

Permalink
Change structure of rectifier need to think about a more consistent way
Browse files Browse the repository at this point in the history
  • Loading branch information
mleprovost committed Feb 22, 2024
1 parent a68d5ff commit 149fe25
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 5 deletions.
112 changes: 112 additions & 0 deletions notebooks/Untitled1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "7164cf12",
"metadata": {},
"outputs": [],
"source": [
"using Revise\n",
"using TransportBasedInference\n",
"using Test\n",
"using LinearAlgebra\n",
"using Statistics\n",
"using ForwardDiff"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "fbe9c3f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0m\u001b[1mTest Summary: | \u001b[22m\u001b[32m\u001b[1mPass \u001b[22m\u001b[39m\u001b[36m\u001b[1mTotal \u001b[22m\u001b[39m\u001b[0m\u001b[1mTime\u001b[22m\n",
"Rectifier sigmoid_ | \u001b[32m 11 \u001b[39m\u001b[36m 11 \u001b[39m\u001b[0m0.6s\n"
]
},
{
"data": {
"text/plain": [
"Test.DefaultTestSet(\"Rectifier sigmoid_\", Any[], 11, false, false, true, 1.708643370495686e9, 1.708643371109264e9, false, \"In[15]\")"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@testset \"Rectifier sigmoid_\" begin\n",
"\n",
" Kmin = 1e-4\n",
" Kmax = 1e7\n",
" ϵ = 1e-9\n",
"\n",
" r = Rectifier(\"sigmoid_\"; Kmin = Kmin, Kmax = Kmax)\n",
" @test r.T == \"sigmoid_\"\n",
" \n",
" x = 0.4\n",
" @test abs(r(x) - (Kmin + (Kmax - Kmin)*exp(x)/(1+exp(x))))<ϵ\n",
" \n",
" \n",
" # Test gradient\n",
" @test abs(ForwardDiff.derivative(y->r(y), x) - grad_x(r, x) ) < ϵ\n",
" \n",
" # Test hessian\n",
" @test abs(ForwardDiff.derivative(z->ForwardDiff.derivative(y->r(y), z),x) - hess_x(r, x) ) < ϵ\n",
" \n",
" # Test gradient of log evaluation\n",
" @test abs(ForwardDiff.derivative(y->log(r(y)), x) - grad_x_logeval(r, x) ) < ϵ\n",
" \n",
"# # Test hessian of log evaluation\n",
" @test abs(ForwardDiff.hessian(y->log(r(y[1])), [x])[1,1] - hess_x_logeval(r, x) ) < ϵ\n",
" \n",
" \n",
" x = -0.4\n",
" @test abs(r(x) - (Kmin + (Kmax - Kmin)*exp(x)/(1+exp(x))))<ϵ\n",
" \n",
" # Test gradient\n",
" @test abs(ForwardDiff.derivative(y->r(y), x) - grad_x(r, x) ) < ϵ\n",
" \n",
" # Test hessian\n",
" @test abs(ForwardDiff.derivative(z->ForwardDiff.derivative(y->r(y), z),x) - hess_x(r, x) ) < ϵ\n",
" \n",
" # Test gradient of log evaluation\n",
" @test abs(ForwardDiff.derivative(y->log(r(y)), x) - grad_x_logeval(r, x) ) < ϵ\n",
" \n",
" # Test hessian of log evaluation\n",
" @test abs(ForwardDiff.hessian(y->log(r(y[1])), [x])[1,1] - hess_x_logeval(r, x) ) < ϵ\n",
" \n",
"end"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "121ce567",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.10.0",
"language": "julia",
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
74 changes: 69 additions & 5 deletions src/hermitemap/rectifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export Rectifier,
square, dsquare, d2square,
softplus, dsoftplus, d2softplus, invsoftplus,
sigmoid, dsigmoid, d2sigmoid, invsigmoid,
sigmoid_, dsigmoid_, d2sigmoid_, invsigmoid_,
explinearunit, dexplinearunit, d2explinearunit, invexplinearunit,
inverse!, inverse, vinverse,
grad_x!, grad_x, vgrad_x,
Expand All @@ -25,8 +26,22 @@ $(TYPEDFIELDS)
"""
struct Rectifier
T::String
Kmin::Union{Nothing, Float64}
Kmax::Union{Nothing, Float64}
function Rectifier(T::String; Kmin = nothing, Kmax = nothing)
if T == "sigmoid_"
@assert Kmin > 0 "Kmin should be > 0 and cannot be nothing"
@assert Kmax > 0 "Kmax should be > 0 and cannot be nothing"
@assert Kmax > Kmin
end
return new(T, Kmin, Kmax)
end
end



const KMIN = 1e-3
const KMAX = 100
square(x) = x^2
dsquare(x) = 2.0*x
d2square(x) = 2.0
Expand All @@ -38,25 +53,45 @@ d2softplus(x) = log(2.0)/(2.0*(1.0 + cosh(log(2.0)*x)))
invsoftplus(x) = min(log(exp(log(2.0)*x) - 1.0)/log(2.0), x)

# Logistic tools
# Sigmoid implementation from NNlib.jl to avoid underflow errors
# Sigmoid implementation from NNlib.jl to avoid underflow errors.

function sigmoid(x)
t = exp(-abs(x))
ifelse(x 0, inv(1 + t), t / (1 + t))
end

function dsigmoid(x)
σ = sigmoid(x)
return σ*(1-σ)
end

function d2sigmoid(x)
σ = sigmoid(x)
# from dσ*(1-σ) - σ*dσ
return σ*(1-σ)*(1-2*σ)
end
invsigmoid(x) = ifelse(x > 0, log(x) - log(1-x), "Not defined for x ≤ 0 ")

function sigmoid_(x, K_min, K_max)
return K_min + (K_max-K_min) * sigmoid(x)
end

function dsigmoid_(x, K_min, K_max)
σ = sigmoid(x)
return (K_max-K_min)*σ*(1-σ)
end

function d2sigmoid_(x, K_min, K_max)
σ = sigmoid(x)
return (K_max-K_min) * σ*(1-σ)*(1-2*σ)
end

function invsigmoid_(x, K_min, K_max)
if x > K_min && x < K_max
return log(x-K_min) - log(K_max-x)
else
return "Not defined for x outside [K_min, K_max]"
end
end

explinearunit(x) = x < 0.0 ? exp(x) : x + 1.0
dexplinearunit(x) = x < 0.0 ? exp(x) : 1.0
Expand All @@ -76,6 +111,8 @@ function (g::Rectifier)(x)
return exp(x)
elseif g.T=="sigmoid"
return sigmoid(x)
elseif g.T=="sigmoid_"
return sigmoid_(x, g.Kmin, g.Kmax)
elseif g.T=="softplus"
return softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -94,6 +131,9 @@ function evaluate!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(sigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y -> sigmoid_(y, g.Kmin, g.Kmax), result, x)
return result
elseif g.T=="softplus"
vmap!(softplus, result, x)
return result
Expand All @@ -103,7 +143,7 @@ function evaluate!(result, g::Rectifier, x)
end
end

vevaluate(g::Rectifier, x) = evaluate!(zero(x), g, x)
evaluate(g::Rectifier, x) = evaluate!(zero(x), g, x)

function inverse(g::Rectifier, x)
@assert x>=0 "Input to rectifier is negative"
Expand All @@ -113,6 +153,8 @@ function inverse(g::Rectifier, x)
return log(x)
elseif g.T=="sigmoid"
return invsigmoid(x)
elseif g.T=="sigmoid_"
return invsigmoid_(x, g.Kmin, g.Kmax)
elseif g.T=="softplus"
return invsoftplus(x)
elseif g.T=="explinearunit"
Expand All @@ -131,6 +173,8 @@ function inverse!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(invsigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->invsigmoid(y, g.Kmin, g.Kmax), result, x)
elseif g.T=="softplus"
vmap!(invsoftplus, result, x)
return result
Expand All @@ -150,6 +194,8 @@ function grad_x(g::Rectifier, x)
return exp(x)
elseif g.T=="sigmoid"
return dsigmoid(x)
elseif g.T=="sigmoid_"
return dsigmoid_(x, g.Kmin, g.Kmax)
elseif g.T=="softplus"
return dsoftplus(x)
elseif g.T=="explinearunit"
Expand All @@ -159,7 +205,7 @@ end


function grad_x!(result, g::Rectifier, x)
@assert size(result,1) == size(x,1) "Dimension of result and x don't match"
@assert size(result,1) == size(x, 1) "Dimension of result and x don't match"
if g.T=="squared"
vmap!(dsquare, result, x)
return result
Expand All @@ -169,6 +215,9 @@ function grad_x!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(dsigmoid, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->dsigmoid_(y, g.Kmin, g.Kmax), result, x)
return result
elseif g.T=="softplus"
vmap!(dsoftplus, result, x)
return result
Expand All @@ -187,7 +236,9 @@ function grad_x_logeval(g::Rectifier, x::T) where {T <: Real}
elseif g.T=="exponential"
return 1.0
elseif g.T=="sigmoid"
return dsigmoid(x)/sigmoid(x)
return dsigmoid(x)/sigmoid(x)
elseif g.T=="sigmoid_"
return dsigmoid_(x, g.Kmin, g.Kmax) / sigmoid_(x, g.Kmin, g.Kmax)
elseif g.T=="softplus"
return dsoftplus(x)/softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -206,6 +257,9 @@ function grad_x_logeval!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(xi->dsigmoid(xi)/sigmoid(xi), result, x)
return result
elseif g.T=="sigmoid_"
vmap!(xi->dsigmoid_(xi, g.Kmin, g.Kmax)/sigmoid_(xi, g.Kmin, g.Kmax), result, x)
return result
elseif g.T=="softplus"
vmap!(xi->dsoftplus(xi)/softplus(xi), result, x)
return result
Expand All @@ -226,6 +280,8 @@ function hess_x_logeval(g::Rectifier, x::T) where {T <: Real}
return 0.0
elseif g.T=="sigmoid"
return (d2sigmoid(x)*sigmoid(x) - dsigmoid(x)^2)/sigmoid(x)^2
elseif g.T=="sigmoid_"
return (d2sigmoid_(x, g.Kmin, g.Kmax)*sigmoid_(x, g.Kmin, g.Kmax) - dsigmoid_(x, g.Kmin, g.Kmax)^2) / sigmoid_(x, g.Kmin, g.Kmax)^2
elseif g.T=="softplus"
return (d2softplus(x)*softplus(x) - dsoftplus(x)^2)/softplus(x)^2
elseif g.T=="explinearunit"
Expand All @@ -244,6 +300,9 @@ function hess_x_logeval!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(xi->(d2sigmoid(xi)*sigmoid(xi) - dsigmoid(xi)^2)/sigmoid(xi)^2, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(xi->(d2sigmoid_(xi, g.Kmin, g.Kmax)*sigmoid_(xi, g.Kmin, g.Kmax) - dsigmoid_(xi, g.Kmin, g.Kmax)^2)/sigmoid_(xi, g.Kmin, g.Kmax)^2, result, x)
return result
elseif g.T=="softplus"
vmap!(xi->(d2softplus(xi)*softplus(xi) - dsoftplus(xi)^2)/softplus(xi)^2, result, x)
return result
Expand All @@ -262,6 +321,8 @@ function hess_x(g::Rectifier, x::T) where {T <: Real}
return exp(x)
elseif g.T=="sigmoid"
return d2sigmoid(x)
elseif g.T=="sigmoid_"
return d2sigmoid_(x, g.Kmin, g.Kmax)
elseif g.T=="softplus"
return d2softplus(x)
elseif g.T=="explinearunit"
Expand All @@ -280,6 +341,9 @@ function hess_x!(result, g::Rectifier, x)
elseif g.T=="sigmoid"
vmap!(d2softplus, result, x)
return result
elseif g.T=="sigmoid_"
vmap!(y->d2sigmoid_(y, g.Kmin, g.Kmax), result, x)
return result
elseif g.T=="softplus"
vmap!(d2softplus, result, x)
return result
Expand Down
43 changes: 43 additions & 0 deletions test/hermitemap/rectifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,49 @@ x = -0.4

end

@testset "Rectifier sigmoid_" begin

Kmin = 1e-4
Kmax = 1e7
ϵ = 1e-9

r = Rectifier("sigmoid_"; Kmin = Kmin, Kmax = Kmax)
@test r.T == "sigmoid_"

x = 0.4
@test abs(r(x) - (Kmin + (Kmax - Kmin)*exp(x)/(1+exp(x))))<ϵ


# Test gradient
@test abs(ForwardDiff.derivative(y->r(y), x) - grad_x(r, x) ) < ϵ

# Test hessian
@test abs(ForwardDiff.derivative(z->ForwardDiff.derivative(y->r(y), z),x) - hess_x(r, x) ) < ϵ

# Test gradient of log evaluation
@test abs(ForwardDiff.derivative(y->log(r(y)), x) - grad_x_logeval(r, x) ) < ϵ

# # Test hessian of log evaluation
@test abs(ForwardDiff.hessian(y->log(r(y[1])), [x])[1,1] - hess_x_logeval(r, x) ) < ϵ


x = -0.4
@test abs(r(x) - (Kmin + (Kmax - Kmin)*exp(x)/(1+exp(x))))<ϵ

# Test gradient
@test abs(ForwardDiff.derivative(y->r(y), x) - grad_x(r, x) ) < ϵ

# Test hessian
@test abs(ForwardDiff.derivative(z->ForwardDiff.derivative(y->r(y), z),x) - hess_x(r, x) ) < ϵ

# Test gradient of log evaluation
@test abs(ForwardDiff.derivative(y->log(r(y)), x) - grad_x_logeval(r, x) ) < ϵ

# Test hessian of log evaluation
@test abs(ForwardDiff.hessian(y->log(r(y[1])), [x])[1,1] - hess_x_logeval(r, x) ) < ϵ

end


@testset "Rectifier softplus" begin

Expand Down

0 comments on commit 149fe25

Please sign in to comment.