From 59118a49b32d9fce9db6b165dc8af7f04a8e9dfb Mon Sep 17 00:00:00 2001 From: Logan Mondal Bhamidipaty <76822456+FlyingWorkshop@users.noreply.github.com> Date: Mon, 1 Apr 2024 18:34:56 -0700 Subject: [PATCH] Added VAE --- Project.toml | 1 - docs/src/compressors.md | 12 +++- docs/src/samplers.md | 2 +- src/CompressedBeliefMDPs.jl | 5 +- src/compressors/autoencoders.jl | 46 ++------------ src/compressors/vae.jl | 103 ++++++++++++++++++++++++++++++++ test/compressor_tests.jl | 1 + 7 files changed, 124 insertions(+), 46 deletions(-) create mode 100644 src/compressors/vae.jl diff --git a/Project.toml b/Project.toml index c9a24f4..c03fcc2 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/docs/src/compressors.md b/docs/src/compressors.md index 26c6984..1ea59b6 100644 --- a/docs/src/compressors.md +++ b/docs/src/compressors.md @@ -7,6 +7,7 @@ CompressedBeliefMDPs currently provides wrappers for the following compression t - a factor analysis compressor, - an isomap compressor, - an autoencoder compressor +- a variational auto-encoder (VAE) compressor ## Principal Component Analysis (PCA) ```@docs @@ -29,10 +30,19 @@ FactorAnalysisCompressor ``` ## Isomap - ```@docs IsomapCompressor ``` +## Autoencoder +```@docs +AutoencoderCompressor +``` + +### Variational Auto-Encoder (VAE) +```@docs +VAECompressor +``` + !!! warning Some compression algorithms aren't optimized for large belief spaces. While they pass our unit tests, they may fail on large POMDPs or without seeding. For large POMDPs, users may want a custom `Compressor`. \ No newline at end of file diff --git a/docs/src/samplers.md b/docs/src/samplers.md index 94bcb94..54e8eff 100644 --- a/docs/src/samplers.md +++ b/docs/src/samplers.md @@ -1,7 +1,7 @@ # Implemented Sampler CompressedBeliefMDPs provides the following generic belief samplers: -- a exploratory belief expansion sampler +- an exploratory belief expansion sampler - a [Policy](https://juliapomdp.github.io/POMDPs.jl/latest/api/#POMDPs.Policy) rollout sampler - an [ExplorationPolicy](https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/policies/#Exploration-Policies) rollout sampler diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index a6b1a0d..330e703 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -30,11 +30,14 @@ export ManifoldCompressor, IsomapCompressor, ### Flux compressors ### - AutoencoderCompressor + AutoencoderCompressor, + VAECompressor include("compressors/compressor.jl") include("compressors/mvs_compressors.jl") include("compressors/manifold_compressors.jl") include("compressors/autoencoders.jl") +include("compressors/vae.jl") + export Sampler, diff --git a/src/compressors/autoencoders.jl b/src/compressors/autoencoders.jl index 5281504..f3bd8ec 100644 --- a/src/compressors/autoencoders.jl +++ b/src/compressors/autoencoders.jl @@ -1,5 +1,6 @@ using Flux + struct AutoencoderCompressor <: Compressor encoder model @@ -7,6 +8,9 @@ struct AutoencoderCompressor <: Compressor epochs end +""" +Implements an autoencoder in Flux. +""" function AutoencoderCompressor(input_dim::Integer, latent_dim::Integer; opt=Adam(), epochs=10) encoder = Dense(input_dim, latent_dim, sigmoid) |> f64 decoder = Chain(Dense(latent_dim => input_dim), softmax) @@ -18,7 +22,6 @@ function fit!(c::AutoencoderCompressor, beliefs) opt_state = Flux.setup(c.optimizer, c.model) data = [(beliefs', beliefs')] loss(m, x, y) = Flux.kldivergence(m(x), y) - # @showprogress for _ in 1:c.epochs for _ in 1:c.epochs Flux.train!(loss, c.model, data, opt_state) end @@ -27,44 +30,3 @@ end function (c::AutoencoderCompressor)(beliefs) return ndims(beliefs) == 2 ? c.encoder(beliefs')' : c.encoder(beliefs) end - - -# struct VAECompressor <: Compressor -# encoder -# model -# optimizer -# epochs -# end - -# # custom split layer from https://fluxml.ai/Flux.jl/dev/models/advanced/#Multiple-outputs:-a-custom-Split-layer -# struct Split{T} -# paths::T -# end - -# Split(paths...) = Split(paths) - -# Flux.@layer Split - -# (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) - - -# " -# Adapted from: https://github.com/FlyingWorkshop/DiffusionGNNTutorial -# " -# function VAECompressor(input_dim::Integer, hidden_dim::Integer=1, latent_dim::Integer; opt=Adam(), epochs=10) -# encoder = Chain( -# Dense(input_dim => hidden_dim, relu), -# Split(Dense(hidden_dim => latent_dim), Dense(hidden_dim => latent_dim)) -# ) |> f64 - -# function model(x) -# μ, σ = encoder(x) -# ϵ = randn(size(μ)...) - -# end - -# encoder = Dense(input_dim, latent_dim, sigmoid) |> f64 -# decoder = Chain(Dense(latent_dim => input_dim), softmax) -# model = Chain(encoder, decoder) |> f64 -# return AutoencoderCompressor(encoder, model, opt, epochs) -# end \ No newline at end of file diff --git a/src/compressors/vae.jl b/src/compressors/vae.jl new file mode 100644 index 0000000..9828e5f --- /dev/null +++ b/src/compressors/vae.jl @@ -0,0 +1,103 @@ +using Flux + + +" +Adapted from: +- https://github.com/FlyingWorkshop/DiffusionGNNTutorial +- https://github.com/FluxML/model-zoo/tree/master/vision/vae_mnist +" + + +struct Encoder + linear + μ + logσ +end + +Flux.@layer Encoder + +Encoder(input_dim::Integer, latent_dim::Integer, hidden_dim::Integer) = Encoder( + Dense(input_dim, hidden_dim, tanh), # linear + Dense(hidden_dim, latent_dim), # μ + Dense(hidden_dim, latent_dim), # logσ +) |> f64 + +function (encoder::Encoder)(x) + h = encoder.linear(x) + encoder.μ(h), encoder.logσ(h) +end + +Decoder(input_dim::Integer, latent_dim::Integer, hidden_dim::Integer) = Chain( + Dense(latent_dim, hidden_dim, tanh), + Dense(hidden_dim, input_dim) +) |> f64 + +function reconstuct(encoder, decoder, x) + μ, logσ = encoder(x) + z = μ + randn(size(logσ)...) .* exp.(logσ) + return μ, logσ, decoder(z) +end + +function model_loss(encoder, decoder, x) + μ, logσ, decoder_z = reconstuct(encoder, decoder, x) + kl_q_p = 0.5f0 * sum(@. (exp(2logσ) + μ^2 - 1 - 2logσ)) + logp_x_z = -Flux.logitbinarycrossentropy(decoder_z, x, agg=sum) + return -logp_x_z + kl_q_p +end + +struct VAECompressor <: Compressor + encoder + decoder + optimizer + epochs + verbose +end + + +""" +Implements a [VAE](https://arxiv.org/abs/1312.6114) in Flux. +""" +function VAECompressor(input_dim::Integer, latent_dim::Integer; hidden_dim::Integer=2, optimizer=Adam(), epochs::Integer=10, verbose=false) + encoder = Encoder(input_dim, latent_dim, hidden_dim) + decoder = Decoder(input_dim, latent_dim, hidden_dim) + VAECompressor(encoder, decoder, optimizer, epochs, verbose) +end + +function fit!(c::VAECompressor, beliefs) + encoder, decoder = c.encoder, c.decoder + opt_enc = Flux.setup(c.optimizer, encoder) + opt_dec = Flux.setup(c.optimizer, decoder) + + if c.verbose + println("Start Training, total $(c.epochs) epochs") + end + + for epoch = 1:c.epochs + if c.verbose + println("Epoch $(epoch)") + end + + for b in eachrow(beliefs) + loss, (grad_enc, grad_dec) = Flux.withgradient(encoder, decoder) do enc, dec + model_loss(enc, dec, b) + end + + Flux.update!(opt_enc, encoder, grad_enc) + Flux.update!(opt_dec, decoder, grad_dec) + + # progress meter + if c.verbose + @show loss + end + end + end +end + +function (c::VAECompressor)(beliefs) + if ndims(beliefs) == 2 + B̃ = c.encoder(beliefs')[1]' + else + B̃ = c.encoder(beliefs)[1] + end + return B̃ +end \ No newline at end of file diff --git a/test/compressor_tests.jl b/test/compressor_tests.jl index 741f5e9..3f4d90c 100644 --- a/test/compressor_tests.jl +++ b/test/compressor_tests.jl @@ -21,6 +21,7 @@ MANIFOLD_COMPRESSORS = ( FLUX_COMPRESSORS = ( AutoencoderCompressor, + VAECompressor ) @testset "Compressor Tests" begin