From 2e1166751c221d9181d343dda07e9078d00d53c0 Mon Sep 17 00:00:00 2001 From: Logan Bhamidipaty Date: Sun, 14 Jul 2024 18:16:17 -0700 Subject: [PATCH] changes --- Project.toml | 1 + figures/kl.jl | 52 +++++++++++++++++++++++++++++-- figures/l2.jl | 61 ++++++++++++++++++++++++++++++++++--- figures/recon.jl | 59 ++++++++++++++++++++++++++++++++++- src/CompressedBeliefMDPs.jl | 1 + src/envs/circular.jl | 54 +++++++++++++++++--------------- src/samplers/rollout.jl | 14 +++++++-- 7 files changed, 208 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index f3bbf50..60fa402 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" diff --git a/figures/kl.jl b/figures/kl.jl index 22ce612..791cd53 100644 --- a/figures/kl.jl +++ b/figures/kl.jl @@ -1,12 +1,58 @@ +using Revise +using Infiltrator + using Plots +using Random +using MultivariateStats + using POMDPs using POMDPTools using CompressedBeliefMDPs +Random.seed!(1) -pomdp = CircularMaze(2, 50) -sampler = PolicySampler(pomdp) +pomdp = CircularMaze(5, 100) +# pomdp = TMaze() +sampler = PolicySampler(pomdp, n=500) +compressor = PCACompressor(10) + +# get beliefs B = sampler(pomdp) -belief = B[1] +# get compressed beliefs +B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix +B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState +fit!(compressor, B_numerical) +B̃ = compressor(B_numerical) +B_recon = reconstruct(compressor.M, B̃')' + +# TODO: add comparison to our reconstruction + +function plot_beliefs(original, reconstructed) + # Define x-axis (assuming the states are ordered sequentially) + x = 1:length(original) + + # Plot the beliefs + plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2) + plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2) + + # Add labels and title + xlabel!("State") + ylabel!("Probability") + title!("An Example Belief and Reconstruction") +end + +# Loop through indices and save each plot in a folder +plots_dir = "plots" +if !isdir(plots_dir) + mkdir(plots_dir) +end + +@show size(compressor.M) +for i in 1:size(B_numerical, 1) + original = B_numerical[i, :] + reconstructed = B_recon[i, :] + plot_beliefs(original, reconstructed) + savefig(joinpath(plots_dir, "belief_plot_$i.png")) +end diff --git a/figures/l2.jl b/figures/l2.jl index e8e897e..791cd53 100644 --- a/figures/l2.jl +++ b/figures/l2.jl @@ -1,5 +1,58 @@ -using +using Revise +using Infiltrator -pomdp = CircularMaze(2, 50, 0.99) -s = initialstate(pomdp) -plot_belief(s) +using Plots +using Random +using MultivariateStats + +using POMDPs +using POMDPTools +using CompressedBeliefMDPs + +Random.seed!(1) + +pomdp = CircularMaze(5, 100) +# pomdp = TMaze() +sampler = PolicySampler(pomdp, n=500) +compressor = PCACompressor(10) + +# get beliefs +B = sampler(pomdp) + +# get compressed beliefs +B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix +B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState +fit!(compressor, B_numerical) +B̃ = compressor(B_numerical) +B_recon = reconstruct(compressor.M, B̃')' + +# TODO: add comparison to our reconstruction + +function plot_beliefs(original, reconstructed) + # Define x-axis (assuming the states are ordered sequentially) + x = 1:length(original) + + # Plot the beliefs + plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2) + plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2) + + # Add labels and title + xlabel!("State") + ylabel!("Probability") + title!("An Example Belief and Reconstruction") +end + +# Loop through indices and save each plot in a folder +plots_dir = "plots" +if !isdir(plots_dir) + mkdir(plots_dir) +end + +@show size(compressor.M) + +for i in 1:size(B_numerical, 1) + original = B_numerical[i, :] + reconstructed = B_recon[i, :] + plot_beliefs(original, reconstructed) + savefig(joinpath(plots_dir, "belief_plot_$i.png")) +end diff --git a/figures/recon.jl b/figures/recon.jl index 77a24b5..791cd53 100644 --- a/figures/recon.jl +++ b/figures/recon.jl @@ -1 +1,58 @@ -using CompressedBeliefMDP \ No newline at end of file +using Revise +using Infiltrator + +using Plots +using Random +using MultivariateStats + +using POMDPs +using POMDPTools +using CompressedBeliefMDPs + +Random.seed!(1) + +pomdp = CircularMaze(5, 100) +# pomdp = TMaze() +sampler = PolicySampler(pomdp, n=500) +compressor = PCACompressor(10) + +# get beliefs +B = sampler(pomdp) + +# get compressed beliefs +B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix +B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState +fit!(compressor, B_numerical) +B̃ = compressor(B_numerical) +B_recon = reconstruct(compressor.M, B̃')' + +# TODO: add comparison to our reconstruction + +function plot_beliefs(original, reconstructed) + # Define x-axis (assuming the states are ordered sequentially) + x = 1:length(original) + + # Plot the beliefs + plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2) + plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2) + + # Add labels and title + xlabel!("State") + ylabel!("Probability") + title!("An Example Belief and Reconstruction") +end + +# Loop through indices and save each plot in a folder +plots_dir = "plots" +if !isdir(plots_dir) + mkdir(plots_dir) +end + +@show size(compressor.M) + +for i in 1:size(B_numerical, 1) + original = B_numerical[i, :] + reconstructed = B_recon[i, :] + plot_beliefs(original, reconstructed) + savefig(joinpath(plots_dir, "belief_plot_$i.png")) +end diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index 42b8952..b404838 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -19,6 +19,7 @@ using StaticArrays using Combinatorics using IterTools using Plots +using ProgressMeter using LinearAlgebra using Parameters diff --git a/src/envs/circular.jl b/src/envs/circular.jl index cee9057..b540576 100644 --- a/src/envs/circular.jl +++ b/src/envs/circular.jl @@ -16,6 +16,8 @@ struct CircularMaze <: POMDP{ r_findgoal::Float64 r_timestep_penalty::Float64 + + states::AbstractArray goals::AbstractArray end @@ -56,6 +58,18 @@ function _make_probabilities(corridor_length::Integer) return probabilities end +function _make_states(n_corridors, corridor_length) + space = Union{CircularMazeState, TerminalState}[] + for i ∈ 1:n_corridors + for j ∈ 1:corridor_length + state = CircularMazeState(i, j) + push!(space, state) + end + end + push!(space, terminalstate) + return space +end + function CircularMaze( n_corridors::Integer, corridor_length::Integer; @@ -78,12 +92,18 @@ function CircularMaze( probabilities = _make_probabilities(corridor_length) center = div(corridor_length, 2) + 1 + + # make states + states = _make_states(n_corridors, corridor_length) + + # make goals goals = [] positions = rand(rng, 1:corridor_length, n_corridors) for (corridor, x) in enumerate(positions) s = CircularMazeState(corridor, x) push!(goals, s) end + pomdp = CircularMaze( n_corridors, corridor_length, @@ -92,6 +112,7 @@ function CircularMaze( discount, r_findgoal, r_timestep_penalty, + states, goals ) return pomdp @@ -116,13 +137,7 @@ function CircularMaze( end function CircularMaze() - pomdp = CircularMaze( - n_corridors=2, - corridor_length=200, - discount=0.99, - r_findgoal=1, - r_timestep_penalty=0 - ) + pomdp = CircularMaze(2, 100) return pomdp end @@ -149,14 +164,7 @@ function POMDPs.actionindex(::CircularMaze, a::Integer) end function POMDPs.states(pomdp::CircularMaze) - space = statetype(pomdp)[] - for i ∈ 1:pomdp.n_corridors - for j ∈ 1:pomdp.corridor_length - state = CircularMazeState(i, j) - push!(space, state) - end - end - push!(space, terminalstate) + space = pomdp.states return space end @@ -172,8 +180,7 @@ end # the initial state distribution is a von Mises distributions each over corridor with a mean at the center function POMDPs.initialstate(pomdp::CircularMaze) - probabilities = repeat(pomdp.probabilities, pomdp.n_corridors) - probabilities /= pomdp.n_corridors # normalize values to sum to 1 + probabilities = repeat(pomdp.probabilities ./ pomdp.n_corridors, pomdp.n_corridors) values = states(pomdp) push!(probabilities, 0) # OBOE from terminal state d = SparseCat(values, probabilities) @@ -208,8 +215,8 @@ function POMDPs.observation( return obs end -function POMDPs.observation(pomdp::CircularMaze, s::TerminalState) - return Deterministic(terminalstate) +function POMDPs.observation(::CircularMaze, s::TerminalState) + return Deterministic(s) end function POMDPs.observations(pomdp::CircularMaze) @@ -245,11 +252,10 @@ function POMDPs.transition( x = s.x end corridor = s.corridor - corridor_states = [] - for x_ ∈ 1:pomdp.corridor_length - s_ = CircularMazeState(corridor, x_) - push!(corridor_states, s_) - end + states = pomdp.states + start = (corridor - 1) * pomdp.corridor_length + 1 + stop = start + pomdp.corridor_length + corridor_states = states[start:stop] probabilities = _center_probabilities(pomdp, x) d = SparseCat(corridor_states, probabilities) end diff --git a/src/samplers/rollout.jl b/src/samplers/rollout.jl index 1a0751f..69a1131 100644 --- a/src/samplers/rollout.jl +++ b/src/samplers/rollout.jl @@ -49,6 +49,8 @@ end function (s::PolicySampler)(pomdp::POMDP) B = [] mdp = GenerativeBeliefMDP(pomdp, s.updater) + progress = Progress(s.n, 1) # Initialize the progress bar + while true b = initialstate(mdp).val for _ in 1:s.n @@ -61,6 +63,7 @@ function (s::PolicySampler)(pomdp::POMDP) end b = @gen(:sp)(mdp, b, a, s.rng) push!(B, b) + next!(progress) # Update the progress bar end end return B @@ -69,7 +72,9 @@ end function (s::PolicySampler)(pomdp::CircularMaze) B = [] - mdp = GenerativeBeliefMDP(pomdp, s.updater) + mdp = GenerativeBeliefMDP(pomdp, s.updater) + progress = Progress(s.n, 1) # Initialize the progress bar + while true b = initialstate(mdp).val for _ in 1:s.n @@ -83,6 +88,7 @@ function (s::PolicySampler)(pomdp::CircularMaze) break else push!(B, b) + next!(progress) # Update the progress bar end end end @@ -91,6 +97,9 @@ end + + + """ ExplorationPolicySampler @@ -146,9 +155,10 @@ function ExplorationPolicySampler(pomdp::POMDP; end + function (s::ExplorationPolicySampler)(pomdp::POMDP) B = [] - mdp = GenerativeBeliefMDP(pomdp, s.updater) + mdp = GenerativeBeliefMDP(pomdp, s.updater) while true b = initialstate(mdp).val for k in 1:s.n