Skip to content

Commit

Permalink
off by ones
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Jul 12, 2024
1 parent 6d22eab commit 1308eec
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 30 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Expand Down
13 changes: 13 additions & 0 deletions figures/kl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Plots
using POMDPs
using POMDPTools
using CompressedBeliefMDPs


pomdp = CircularMaze(2, 50)
sampler = PolicySampler(pomdp)
B = sampler(pomdp)

belief = B[1]


5 changes: 5 additions & 0 deletions figures/l2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using

pomdp = CircularMaze(2, 50, 0.99)
s = initialstate(pomdp)
plot_belief(s)
1 change: 1 addition & 0 deletions figures/recon.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
using CompressedBeliefMDP
11 changes: 7 additions & 4 deletions src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
module CompressedBeliefMDPs

using Infiltrator # TODO: remove
using Infiltrator
# TODO: remove plots, revise, Infiltrator

# Packages from JuliaPOMDPs
using POMDPs, POMDPTools, POMDPModels
using POMDPs
using POMDPTools
using POMDPModels
using ParticleFilters
using LocalApproximationValueIteration
using LocalFunctionApproximation
Expand All @@ -14,7 +17,8 @@ using Bijections
using NearestNeighbors
using StaticArrays
using Combinatorics
using IterTools # TODO: not sure if I need this anymore
using IterTools
using Plots

using LinearAlgebra
using Parameters
Expand All @@ -30,7 +34,6 @@ export
CMAZE_DECLARE_GOAL
include("envs/circular.jl")


export
### Compressor Interface ###
Compressor,
Expand Down
76 changes: 63 additions & 13 deletions src/envs/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ struct CircularMaze <: POMDP{
discount::Float64

r_findgoal::Float64
r_timestep_penalty::Float64
goals::AbstractArray
# TODO: add RNG support
end

# get the mass of each state
function _get_mass(d, x1, x2)
@assert x2 >= x1
@assert minimum(d) <= x1 && x1 <= maximum(d)
@assert minimum(d) <= x2 && x2 <= maximum(d)
@assert x2 >= x1 "x2 ($x2) should be greater than or equal to x1 ($x1)"
@assert minimum(d) <= x1 <= maximum(d) """
x1 ($x1) should be within the distribution's range
(minimum: $(minimum(d)), maximum: $(maximum(d)))
"""
@assert minimum(d) <= x2 <= maximum(d) """
x2 ($x2) should be within the distribution's range
(minimum: $(minimum(d)), maximum: $(maximum(d)))
"""

c1 = cdf(d, x1)
c2 = cdf(d, x2)
Expand All @@ -37,28 +43,38 @@ function _make_probabilities(corridor_length::Integer)

d = VonMises()
min_ = minimum(d)
max_ = maximum(d)
step_size = (max_ - min_) / corridor_length
max_ = maximum(d)
a = range(min_, max_, length=corridor_length + 1)
a1 = a[1:end-1]
a2 = a[2:end]
probabilities = []
for x1 in min_:step_size:(max_ - 1)
x2 = x1 + step_size
for (x1, x2) in zip(a1, a2)
m = _get_mass(d, x1, x2)
push!(probabilities, m)
end

return probabilities
end

function CircularMaze(
n_corridors::Integer,
corridor_length::Integer,
discount::Float64;
corridor_length::Integer;
discount::Float64 = 0.99,
r_findgoal::Float64 = 1.0,
r_timestep_penalty::Float64 = 0.0,
rng::AbstractRNG = MersenneTwister()
)
@assert n_corridors > 0 "Number of corridors must be a positive integer."
@assert corridor_length > 0 "Corridor length must be a positive integer."
@assert 0.0 <= discount <= 1.0 "Discount factor must be between 0 and 1."
@assert r_findgoal >= 0.0 "Reward for finding the goal must be non-negative."
@assert r_timestep_penalty >= 0.0 "The timestep penalty must be non-negative."

if typeof(n_corridors) != typeof(corridor_length)
type1 = typeof(n_corridors)
type2 = typeof(corridor_length)
@warn "n_corridors ($type1) and corridor_length ($type2) are not of the same type."
end

probabilities = _make_probabilities(corridor_length)
center = div(corridor_length, 2) + 1
Expand All @@ -68,12 +84,45 @@ function CircularMaze(
s = CircularMazeState(corridor, x)
push!(goals, s)
end
pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount, r_findgoal, goals)
pomdp = CircularMaze(
n_corridors,
corridor_length,
probabilities,
center,
discount,
r_findgoal,
r_timestep_penalty,
goals
)
return pomdp
end

# conveience constructors
function CircularMaze(
n_corridors::Integer,
corridor_length::Integer,
discount::Float64,
r_findgoal::Float64,
r_timestep_penalty::Float64,
)
pomdp = CircularMaze(
n_corridors,
corridor_length;
discount,
r_findgoal=r_findgoal,
r_timestep_penalty=r_timestep_penalty,
)
return pomdp
end

function CircularMaze()
pomdp = CircularMaze(2, 200, 0.99)
pomdp = CircularMaze(
n_corridors=2,
corridor_length=200,
discount=0.99,
r_findgoal=1,
r_timestep_penalty=0
)
return pomdp
end

Expand Down Expand Up @@ -229,6 +278,7 @@ function POMDPs.reward(
else
r = 0
end
r -= pomdp.r_timestep_penalty
return r
end

Expand All @@ -247,7 +297,7 @@ function POMDPs.discount(pomdp::CircularMaze)
end

## hack to avoid exploring terminal states
global CMAZE_TERMINAL_FLAG = false
CMAZE_TERMINAL_FLAG = false
function POMDPTools.ModelTools.gbmdp_handle_terminal(::CircularMaze, ::Updater, b, s, a, rng)
global CMAZE_TERMINAL_FLAG = true
return b
Expand Down
3 changes: 2 additions & 1 deletion src/samplers/expansion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ function _exploratory_belief_expansion!(
end

function (s::BeliefExpansionSampler)(pomdp::POMDP)
b0 = initialize_belief(s.updater, initialstate(pomdp))
s0 = initialstate(pomdp)
b0 = initialize_belief(s.updater, s0)
b0_numeric = _make_numeric(b0, pomdp)
B = Set([b0])
B_numeric = [b0_numeric]
Expand Down
26 changes: 16 additions & 10 deletions test/circular_tests.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
let
@testset "CircularMaze" begin
pomdp = CircularMaze(2, 5, 0.99)
@test has_consistent_distributions(pomdp)

# test non-exported solvers
# solvers = [
# MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0),
# ]
# @testset "$solver" for solver in solvers
# @test test_solver(solver, pomdp)
# # @test_nowarn test_solver(solver, pomdp)
# end
@testset "Solvers" begin
@testset "MCTS" begin
solver = MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0)
@test_nowarn test_solver(solver, pomdp)
end

# @testset "CompressedSolver" begin
# # TODO: compressed solver
# solver = MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0)
# # @test_nowarn test_solver(solver, pomdp)
# end
end


# test CompressedBeliefSolver
@testset "Samplers" begin
@testset "PolicySampler" begin
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
end
@testset "ExplorationPolicySampler" begin
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=ExplorationPolicySampler(pomdp)), pomdp)
end
@testset "BeliefExpansionSampler" begin
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=BeliefExpansionSampler(pomdp)), pomdp)
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/solver_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
end
@testset "ExplorationPolicySampler" begin
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=ExplorationPolicySampler(pomdp)), pomdp)
end
@testset "BeliefExpansionSampler" begin
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp)
@test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=BeliefExpansionSampler(pomdp)), pomdp)
end
end
end
Expand Down

0 comments on commit 1308eec

Please sign in to comment.