diff --git a/Project.toml b/Project.toml index 233554d..f3bbf50 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,9 @@ POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] diff --git a/figures/kl.jl b/figures/kl.jl new file mode 100644 index 0000000..a8ce402 --- /dev/null +++ b/figures/kl.jl @@ -0,0 +1,13 @@ +using Plots +using POMDPs +using POMDPTools +using CompressedBeliefMDPs + + +pomdp = CircularMaze(2, 50) +sampler = PolicySampler(pomdp) +B = sampler(pomdp) + +belief = B[1] + + diff --git a/figures/l2.jl b/figures/l2.jl new file mode 100644 index 0000000..e8e897e --- /dev/null +++ b/figures/l2.jl @@ -0,0 +1,5 @@ +using + +pomdp = CircularMaze(2, 50, 0.99) +s = initialstate(pomdp) +plot_belief(s) diff --git a/figures/recon.jl b/figures/recon.jl new file mode 100644 index 0000000..77a24b5 --- /dev/null +++ b/figures/recon.jl @@ -0,0 +1 @@ +using CompressedBeliefMDP \ No newline at end of file diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index e60d207..42b8952 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -1,9 +1,12 @@ module CompressedBeliefMDPs -using Infiltrator # TODO: remove +using Infiltrator +# TODO: remove plots, revise, Infiltrator # Packages from JuliaPOMDPs -using POMDPs, POMDPTools, POMDPModels +using POMDPs +using POMDPTools +using POMDPModels using ParticleFilters using LocalApproximationValueIteration using LocalFunctionApproximation @@ -14,7 +17,8 @@ using Bijections using NearestNeighbors using StaticArrays using Combinatorics -using IterTools # TODO: not sure if I need this anymore +using IterTools +using Plots using LinearAlgebra using Parameters @@ -30,7 +34,6 @@ export CMAZE_DECLARE_GOAL include("envs/circular.jl") - export ### Compressor Interface ### Compressor, diff --git a/src/envs/circular.jl b/src/envs/circular.jl index e412b1a..cee9057 100644 --- a/src/envs/circular.jl +++ b/src/envs/circular.jl @@ -15,15 +15,21 @@ struct CircularMaze <: POMDP{ discount::Float64 r_findgoal::Float64 + r_timestep_penalty::Float64 goals::AbstractArray - # TODO: add RNG support end # get the mass of each state function _get_mass(d, x1, x2) - @assert x2 >= x1 - @assert minimum(d) <= x1 && x1 <= maximum(d) - @assert minimum(d) <= x2 && x2 <= maximum(d) + @assert x2 >= x1 "x2 ($x2) should be greater than or equal to x1 ($x1)" + @assert minimum(d) <= x1 <= maximum(d) """ + x1 ($x1) should be within the distribution's range + (minimum: $(minimum(d)), maximum: $(maximum(d))) + """ + @assert minimum(d) <= x2 <= maximum(d) """ + x2 ($x2) should be within the distribution's range + (minimum: $(minimum(d)), maximum: $(maximum(d))) + """ c1 = cdf(d, x1) c2 = cdf(d, x2) @@ -37,28 +43,38 @@ function _make_probabilities(corridor_length::Integer) d = VonMises() min_ = minimum(d) - max_ = maximum(d) - step_size = (max_ - min_) / corridor_length + max_ = maximum(d) + a = range(min_, max_, length=corridor_length + 1) + a1 = a[1:end-1] + a2 = a[2:end] probabilities = [] - for x1 in min_:step_size:(max_ - 1) - x2 = x1 + step_size + for (x1, x2) in zip(a1, a2) m = _get_mass(d, x1, x2) push!(probabilities, m) end + return probabilities end function CircularMaze( n_corridors::Integer, - corridor_length::Integer, - discount::Float64; + corridor_length::Integer; + discount::Float64 = 0.99, r_findgoal::Float64 = 1.0, + r_timestep_penalty::Float64 = 0.0, rng::AbstractRNG = MersenneTwister() ) @assert n_corridors > 0 "Number of corridors must be a positive integer." @assert corridor_length > 0 "Corridor length must be a positive integer." @assert 0.0 <= discount <= 1.0 "Discount factor must be between 0 and 1." @assert r_findgoal >= 0.0 "Reward for finding the goal must be non-negative." + @assert r_timestep_penalty >= 0.0 "The timestep penalty must be non-negative." + + if typeof(n_corridors) != typeof(corridor_length) + type1 = typeof(n_corridors) + type2 = typeof(corridor_length) + @warn "n_corridors ($type1) and corridor_length ($type2) are not of the same type." + end probabilities = _make_probabilities(corridor_length) center = div(corridor_length, 2) + 1 @@ -68,12 +84,45 @@ function CircularMaze( s = CircularMazeState(corridor, x) push!(goals, s) end - pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount, r_findgoal, goals) + pomdp = CircularMaze( + n_corridors, + corridor_length, + probabilities, + center, + discount, + r_findgoal, + r_timestep_penalty, + goals + ) + return pomdp +end + +# conveience constructors +function CircularMaze( + n_corridors::Integer, + corridor_length::Integer, + discount::Float64, + r_findgoal::Float64, + r_timestep_penalty::Float64, +) + pomdp = CircularMaze( + n_corridors, + corridor_length; + discount, + r_findgoal=r_findgoal, + r_timestep_penalty=r_timestep_penalty, + ) return pomdp end function CircularMaze() - pomdp = CircularMaze(2, 200, 0.99) + pomdp = CircularMaze( + n_corridors=2, + corridor_length=200, + discount=0.99, + r_findgoal=1, + r_timestep_penalty=0 + ) return pomdp end @@ -229,6 +278,7 @@ function POMDPs.reward( else r = 0 end + r -= pomdp.r_timestep_penalty return r end @@ -247,7 +297,7 @@ function POMDPs.discount(pomdp::CircularMaze) end ## hack to avoid exploring terminal states -global CMAZE_TERMINAL_FLAG = false +CMAZE_TERMINAL_FLAG = false function POMDPTools.ModelTools.gbmdp_handle_terminal(::CircularMaze, ::Updater, b, s, a, rng) global CMAZE_TERMINAL_FLAG = true return b diff --git a/src/samplers/expansion.jl b/src/samplers/expansion.jl index e709d6b..c4f7b98 100644 --- a/src/samplers/expansion.jl +++ b/src/samplers/expansion.jl @@ -100,7 +100,8 @@ function _exploratory_belief_expansion!( end function (s::BeliefExpansionSampler)(pomdp::POMDP) - b0 = initialize_belief(s.updater, initialstate(pomdp)) + s0 = initialstate(pomdp) + b0 = initialize_belief(s.updater, s0) b0_numeric = _make_numeric(b0, pomdp) B = Set([b0]) B_numeric = [b0_numeric] diff --git a/test/circular_tests.jl b/test/circular_tests.jl index 01a977c..41a4e05 100644 --- a/test/circular_tests.jl +++ b/test/circular_tests.jl @@ -1,15 +1,21 @@ -let +@testset "CircularMaze" begin pomdp = CircularMaze(2, 5, 0.99) @test has_consistent_distributions(pomdp) # test non-exported solvers - # solvers = [ - # MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0), - # ] - # @testset "$solver" for solver in solvers - # @test test_solver(solver, pomdp) - # # @test_nowarn test_solver(solver, pomdp) - # end + @testset "Solvers" begin + @testset "MCTS" begin + solver = MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0) + @test_nowarn test_solver(solver, pomdp) + end + + # @testset "CompressedSolver" begin + # # TODO: compressed solver + # solver = MCTSSolver(n_iterations=10, depth=5, exploration_constant=5.0) + # # @test_nowarn test_solver(solver, pomdp) + # end + end + # test CompressedBeliefSolver @testset "Samplers" begin @@ -17,10 +23,10 @@ let @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) end @testset "ExplorationPolicySampler" begin - @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) + @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=ExplorationPolicySampler(pomdp)), pomdp) end @testset "BeliefExpansionSampler" begin - @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) + @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=BeliefExpansionSampler(pomdp)), pomdp) end end diff --git a/test/solver_tests.jl b/test/solver_tests.jl index 48f66d1..93fc6ed 100644 --- a/test/solver_tests.jl +++ b/test/solver_tests.jl @@ -9,10 +9,10 @@ @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) end @testset "ExplorationPolicySampler" begin - @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) + @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=ExplorationPolicySampler(pomdp)), pomdp) end @testset "BeliefExpansionSampler" begin - @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=PolicySampler(pomdp)), pomdp) + @test_nowarn test_solver(CompressedBeliefSolver(pomdp; sampler=BeliefExpansionSampler(pomdp)), pomdp) end end end