Skip to content

Commit

Permalink
Added VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Apr 2, 2024
1 parent fd959b7 commit 59118a4
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 46 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand Down
12 changes: 11 additions & 1 deletion docs/src/compressors.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CompressedBeliefMDPs currently provides wrappers for the following compression t
- a factor analysis compressor,
- an isomap compressor,
- an autoencoder compressor
- a variational auto-encoder (VAE) compressor

## Principal Component Analysis (PCA)
```@docs
Expand All @@ -29,10 +30,19 @@ FactorAnalysisCompressor
```

## Isomap

```@docs
IsomapCompressor
```

## Autoencoder
```@docs
AutoencoderCompressor
```

### Variational Auto-Encoder (VAE)
```@docs
VAECompressor
```

!!! warning
Some compression algorithms aren't optimized for large belief spaces. While they pass our unit tests, they may fail on large POMDPs or without seeding. For large POMDPs, users may want a custom `Compressor`.
2 changes: 1 addition & 1 deletion docs/src/samplers.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Implemented Sampler

CompressedBeliefMDPs provides the following generic belief samplers:
- a exploratory belief expansion sampler
- an exploratory belief expansion sampler
- a [Policy](https://juliapomdp.github.io/POMDPs.jl/latest/api/#POMDPs.Policy) rollout sampler
- an [ExplorationPolicy](https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/policies/#Exploration-Policies) rollout sampler

Expand Down
5 changes: 4 additions & 1 deletion src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,14 @@ export
ManifoldCompressor,
IsomapCompressor,
### Flux compressors ###
AutoencoderCompressor
AutoencoderCompressor,
VAECompressor
include("compressors/compressor.jl")
include("compressors/mvs_compressors.jl")
include("compressors/manifold_compressors.jl")
include("compressors/autoencoders.jl")
include("compressors/vae.jl")


export
Sampler,
Expand Down
46 changes: 4 additions & 42 deletions src/compressors/autoencoders.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
using Flux


struct AutoencoderCompressor <: Compressor
encoder
model
optimizer
epochs
end

"""
Implements an autoencoder in Flux.
"""
function AutoencoderCompressor(input_dim::Integer, latent_dim::Integer; opt=Adam(), epochs=10)
encoder = Dense(input_dim, latent_dim, sigmoid) |> f64
decoder = Chain(Dense(latent_dim => input_dim), softmax)
Expand All @@ -18,7 +22,6 @@ function fit!(c::AutoencoderCompressor, beliefs)
opt_state = Flux.setup(c.optimizer, c.model)
data = [(beliefs', beliefs')]
loss(m, x, y) = Flux.kldivergence(m(x), y)
# @showprogress for _ in 1:c.epochs
for _ in 1:c.epochs
Flux.train!(loss, c.model, data, opt_state)
end
Expand All @@ -27,44 +30,3 @@ end
function (c::AutoencoderCompressor)(beliefs)
return ndims(beliefs) == 2 ? c.encoder(beliefs')' : c.encoder(beliefs)
end


# struct VAECompressor <: Compressor
# encoder
# model
# optimizer
# epochs
# end

# # custom split layer from https://fluxml.ai/Flux.jl/dev/models/advanced/#Multiple-outputs:-a-custom-Split-layer
# struct Split{T}
# paths::T
# end

# Split(paths...) = Split(paths)

# Flux.@layer Split

# (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)


# "
# Adapted from: https://github.com/FlyingWorkshop/DiffusionGNNTutorial
# "
# function VAECompressor(input_dim::Integer, hidden_dim::Integer=1, latent_dim::Integer; opt=Adam(), epochs=10)
# encoder = Chain(
# Dense(input_dim => hidden_dim, relu),
# Split(Dense(hidden_dim => latent_dim), Dense(hidden_dim => latent_dim))
# ) |> f64

# function model(x)
# μ, σ = encoder(x)
# ϵ = randn(size(μ)...)

# end

# encoder = Dense(input_dim, latent_dim, sigmoid) |> f64
# decoder = Chain(Dense(latent_dim => input_dim), softmax)
# model = Chain(encoder, decoder) |> f64
# return AutoencoderCompressor(encoder, model, opt, epochs)
# end
103 changes: 103 additions & 0 deletions src/compressors/vae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using Flux


"
Adapted from:
- https://github.com/FlyingWorkshop/DiffusionGNNTutorial
- https://github.com/FluxML/model-zoo/tree/master/vision/vae_mnist
"


struct Encoder
linear
μ
logσ
end

Flux.@layer Encoder

Encoder(input_dim::Integer, latent_dim::Integer, hidden_dim::Integer) = Encoder(
Dense(input_dim, hidden_dim, tanh), # linear
Dense(hidden_dim, latent_dim), # μ
Dense(hidden_dim, latent_dim), # logσ
) |> f64

function (encoder::Encoder)(x)
h = encoder.linear(x)
encoder.μ(h), encoder.logσ(h)
end

Decoder(input_dim::Integer, latent_dim::Integer, hidden_dim::Integer) = Chain(
Dense(latent_dim, hidden_dim, tanh),
Dense(hidden_dim, input_dim)
) |> f64

function reconstuct(encoder, decoder, x)
μ, logσ = encoder(x)
z = μ + randn(size(logσ)...) .* exp.(logσ)
return μ, logσ, decoder(z)
end

function model_loss(encoder, decoder, x)
μ, logσ, decoder_z = reconstuct(encoder, decoder, x)
kl_q_p = 0.5f0 * sum(@. (exp(2logσ) + μ^2 - 1 - 2logσ))
logp_x_z = -Flux.logitbinarycrossentropy(decoder_z, x, agg=sum)
return -logp_x_z + kl_q_p
end

struct VAECompressor <: Compressor
encoder
decoder
optimizer
epochs
verbose
end


"""
Implements a [VAE](https://arxiv.org/abs/1312.6114) in Flux.
"""
function VAECompressor(input_dim::Integer, latent_dim::Integer; hidden_dim::Integer=2, optimizer=Adam(), epochs::Integer=10, verbose=false)
encoder = Encoder(input_dim, latent_dim, hidden_dim)
decoder = Decoder(input_dim, latent_dim, hidden_dim)
VAECompressor(encoder, decoder, optimizer, epochs, verbose)
end

function fit!(c::VAECompressor, beliefs)
encoder, decoder = c.encoder, c.decoder
opt_enc = Flux.setup(c.optimizer, encoder)
opt_dec = Flux.setup(c.optimizer, decoder)

if c.verbose
println("Start Training, total $(c.epochs) epochs")
end

for epoch = 1:c.epochs
if c.verbose
println("Epoch $(epoch)")
end

for b in eachrow(beliefs)
loss, (grad_enc, grad_dec) = Flux.withgradient(encoder, decoder) do enc, dec
model_loss(enc, dec, b)
end

Flux.update!(opt_enc, encoder, grad_enc)
Flux.update!(opt_dec, decoder, grad_dec)

# progress meter
if c.verbose
@show loss
end
end
end
end

function (c::VAECompressor)(beliefs)
if ndims(beliefs) == 2
= c.encoder(beliefs')[1]'
else
= c.encoder(beliefs)[1]
end
return
end
1 change: 1 addition & 0 deletions test/compressor_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ MANIFOLD_COMPRESSORS = (

FLUX_COMPRESSORS = (
AutoencoderCompressor,
VAECompressor
)

@testset "Compressor Tests" begin
Expand Down

0 comments on commit 59118a4

Please sign in to comment.