Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Jul 15, 2024
1 parent 7371e41 commit 2e11667
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 34 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
52 changes: 49 additions & 3 deletions figures/kl.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,58 @@
using Revise
using Infiltrator

using Plots
using Random
using MultivariateStats

using POMDPs
using POMDPTools
using CompressedBeliefMDPs

Random.seed!(1)

pomdp = CircularMaze(2, 50)
sampler = PolicySampler(pomdp)
pomdp = CircularMaze(5, 100)
# pomdp = TMaze()
sampler = PolicySampler(pomdp, n=500)
compressor = PCACompressor(10)

# get beliefs
B = sampler(pomdp)

belief = B[1]
# get compressed beliefs
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix
B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState
fit!(compressor, B_numerical)
= compressor(B_numerical)
B_recon = reconstruct(compressor.M, B̃')'

# TODO: add comparison to our reconstruction

function plot_beliefs(original, reconstructed)
# Define x-axis (assuming the states are ordered sequentially)
x = 1:length(original)

# Plot the beliefs
plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2)
plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2)

# Add labels and title
xlabel!("State")
ylabel!("Probability")
title!("An Example Belief and Reconstruction")
end

# Loop through indices and save each plot in a folder
plots_dir = "plots"
if !isdir(plots_dir)
mkdir(plots_dir)
end

@show size(compressor.M)

for i in 1:size(B_numerical, 1)
original = B_numerical[i, :]
reconstructed = B_recon[i, :]
plot_beliefs(original, reconstructed)
savefig(joinpath(plots_dir, "belief_plot_$i.png"))
end
61 changes: 57 additions & 4 deletions figures/l2.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,58 @@
using
using Revise
using Infiltrator

pomdp = CircularMaze(2, 50, 0.99)
s = initialstate(pomdp)
plot_belief(s)
using Plots
using Random
using MultivariateStats

using POMDPs
using POMDPTools
using CompressedBeliefMDPs

Random.seed!(1)

pomdp = CircularMaze(5, 100)
# pomdp = TMaze()
sampler = PolicySampler(pomdp, n=500)
compressor = PCACompressor(10)

# get beliefs
B = sampler(pomdp)

# get compressed beliefs
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix
B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState
fit!(compressor, B_numerical)
= compressor(B_numerical)
B_recon = reconstruct(compressor.M, B̃')'

# TODO: add comparison to our reconstruction

function plot_beliefs(original, reconstructed)
# Define x-axis (assuming the states are ordered sequentially)
x = 1:length(original)

# Plot the beliefs
plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2)
plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2)

# Add labels and title
xlabel!("State")
ylabel!("Probability")
title!("An Example Belief and Reconstruction")
end

# Loop through indices and save each plot in a folder
plots_dir = "plots"
if !isdir(plots_dir)
mkdir(plots_dir)
end

@show size(compressor.M)

for i in 1:size(B_numerical, 1)
original = B_numerical[i, :]
reconstructed = B_recon[i, :]
plot_beliefs(original, reconstructed)
savefig(joinpath(plots_dir, "belief_plot_$i.png"))
end
59 changes: 58 additions & 1 deletion figures/recon.jl
Original file line number Diff line number Diff line change
@@ -1 +1,58 @@
using CompressedBeliefMDP
using Revise
using Infiltrator

using Plots
using Random
using MultivariateStats

using POMDPs
using POMDPTools
using CompressedBeliefMDPs

Random.seed!(1)

pomdp = CircularMaze(5, 100)
# pomdp = TMaze()
sampler = PolicySampler(pomdp, n=500)
compressor = PCACompressor(10)

# get beliefs
B = sampler(pomdp)

# get compressed beliefs
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' |> Matrix
B_numerical = B_numerical[:,1:end-1] # ignore belief in TerminalState
fit!(compressor, B_numerical)
= compressor(B_numerical)
B_recon = reconstruct(compressor.M, B̃')'

# TODO: add comparison to our reconstruction

function plot_beliefs(original, reconstructed)
# Define x-axis (assuming the states are ordered sequentially)
x = 1:length(original)

# Plot the beliefs
plot(x, original, label="Original Belief", linestyle=:solid, linewidth=2)
plot!(x, reconstructed, label="Reconstructed Belief", linestyle=:dash, linewidth=2)

# Add labels and title
xlabel!("State")
ylabel!("Probability")
title!("An Example Belief and Reconstruction")
end

# Loop through indices and save each plot in a folder
plots_dir = "plots"
if !isdir(plots_dir)
mkdir(plots_dir)
end

@show size(compressor.M)

for i in 1:size(B_numerical, 1)
original = B_numerical[i, :]
reconstructed = B_recon[i, :]
plot_beliefs(original, reconstructed)
savefig(joinpath(plots_dir, "belief_plot_$i.png"))
end
1 change: 1 addition & 0 deletions src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using StaticArrays
using Combinatorics
using IterTools
using Plots
using ProgressMeter

using LinearAlgebra
using Parameters
Expand Down
54 changes: 30 additions & 24 deletions src/envs/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct CircularMaze <: POMDP{

r_findgoal::Float64
r_timestep_penalty::Float64

states::AbstractArray
goals::AbstractArray
end

Expand Down Expand Up @@ -56,6 +58,18 @@ function _make_probabilities(corridor_length::Integer)
return probabilities
end

function _make_states(n_corridors, corridor_length)
space = Union{CircularMazeState, TerminalState}[]
for i 1:n_corridors
for j 1:corridor_length
state = CircularMazeState(i, j)
push!(space, state)
end
end
push!(space, terminalstate)
return space
end

function CircularMaze(
n_corridors::Integer,
corridor_length::Integer;
Expand All @@ -78,12 +92,18 @@ function CircularMaze(

probabilities = _make_probabilities(corridor_length)
center = div(corridor_length, 2) + 1

# make states
states = _make_states(n_corridors, corridor_length)

# make goals
goals = []
positions = rand(rng, 1:corridor_length, n_corridors)
for (corridor, x) in enumerate(positions)
s = CircularMazeState(corridor, x)
push!(goals, s)
end

pomdp = CircularMaze(
n_corridors,
corridor_length,
Expand All @@ -92,6 +112,7 @@ function CircularMaze(
discount,
r_findgoal,
r_timestep_penalty,
states,
goals
)
return pomdp
Expand All @@ -116,13 +137,7 @@ function CircularMaze(
end

function CircularMaze()
pomdp = CircularMaze(
n_corridors=2,
corridor_length=200,
discount=0.99,
r_findgoal=1,
r_timestep_penalty=0
)
pomdp = CircularMaze(2, 100)
return pomdp
end

Expand All @@ -149,14 +164,7 @@ function POMDPs.actionindex(::CircularMaze, a::Integer)
end

function POMDPs.states(pomdp::CircularMaze)
space = statetype(pomdp)[]
for i 1:pomdp.n_corridors
for j 1:pomdp.corridor_length
state = CircularMazeState(i, j)
push!(space, state)
end
end
push!(space, terminalstate)
space = pomdp.states
return space
end

Expand All @@ -172,8 +180,7 @@ end

# the initial state distribution is a von Mises distributions each over corridor with a mean at the center
function POMDPs.initialstate(pomdp::CircularMaze)
probabilities = repeat(pomdp.probabilities, pomdp.n_corridors)
probabilities /= pomdp.n_corridors # normalize values to sum to 1
probabilities = repeat(pomdp.probabilities ./ pomdp.n_corridors, pomdp.n_corridors)
values = states(pomdp)
push!(probabilities, 0) # OBOE from terminal state
d = SparseCat(values, probabilities)
Expand Down Expand Up @@ -208,8 +215,8 @@ function POMDPs.observation(
return obs
end

function POMDPs.observation(pomdp::CircularMaze, s::TerminalState)
return Deterministic(terminalstate)
function POMDPs.observation(::CircularMaze, s::TerminalState)
return Deterministic(s)
end

function POMDPs.observations(pomdp::CircularMaze)
Expand Down Expand Up @@ -245,11 +252,10 @@ function POMDPs.transition(
x = s.x
end
corridor = s.corridor
corridor_states = []
for x_ 1:pomdp.corridor_length
s_ = CircularMazeState(corridor, x_)
push!(corridor_states, s_)
end
states = pomdp.states
start = (corridor - 1) * pomdp.corridor_length + 1
stop = start + pomdp.corridor_length
corridor_states = states[start:stop]
probabilities = _center_probabilities(pomdp, x)
d = SparseCat(corridor_states, probabilities)
end
Expand Down
14 changes: 12 additions & 2 deletions src/samplers/rollout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ end
function (s::PolicySampler)(pomdp::POMDP)
B = []
mdp = GenerativeBeliefMDP(pomdp, s.updater)
progress = Progress(s.n, 1) # Initialize the progress bar

while true
b = initialstate(mdp).val
for _ in 1:s.n
Expand All @@ -61,6 +63,7 @@ function (s::PolicySampler)(pomdp::POMDP)
end
b = @gen(:sp)(mdp, b, a, s.rng)
push!(B, b)
next!(progress) # Update the progress bar
end
end
return B
Expand All @@ -69,7 +72,9 @@ end

function (s::PolicySampler)(pomdp::CircularMaze)
B = []
mdp = GenerativeBeliefMDP(pomdp, s.updater)
mdp = GenerativeBeliefMDP(pomdp, s.updater)
progress = Progress(s.n, 1) # Initialize the progress bar

while true
b = initialstate(mdp).val
for _ in 1:s.n
Expand All @@ -83,6 +88,7 @@ function (s::PolicySampler)(pomdp::CircularMaze)
break
else
push!(B, b)
next!(progress) # Update the progress bar
end
end
end
Expand All @@ -91,6 +97,9 @@ end






"""
ExplorationPolicySampler
Expand Down Expand Up @@ -146,9 +155,10 @@ function ExplorationPolicySampler(pomdp::POMDP;
end



function (s::ExplorationPolicySampler)(pomdp::POMDP)
B = []
mdp = GenerativeBeliefMDP(pomdp, s.updater)
mdp = GenerativeBeliefMDP(pomdp, s.updater)
while true
b = initialstate(mdp).val
for k in 1:s.n
Expand Down

0 comments on commit 2e11667

Please sign in to comment.