diff --git a/src/envs/circular.jl b/src/envs/circular.jl index efa4d10..b8a48d8 100644 --- a/src/envs/circular.jl +++ b/src/envs/circular.jl @@ -1,6 +1,6 @@ struct CircularMazeState - corridor::Integer # corridor number - x::Integer # position in corridor + corridor::Integer # corridor number ∈ [1, ..., n_corridors] + x::Integer # position in corridor ∈ [1, ..., corridor_length] end struct CircularMaze <: POMDP{Union{CircularMazeState, TerminalState}, Integer, Integer} @@ -9,6 +9,9 @@ struct CircularMaze <: POMDP{Union{CircularMazeState, TerminalState}, Integer, I probabilities::AbstractArray # probability masses for creating von Mises distributions center::Integer discount::Float64 + + r_findgoal::Float64 + goals::AbstractArray end function _get_mass(d, x1, x2) @@ -19,14 +22,11 @@ function _get_mass(d, x1, x2) return m end -function CircularMaze( - n_corridors::Integer, - corridor_length::Integer, - discount::Float64, -) - d = VonMises() - min_ = minimum(d) - max_ = maximum(d) +# get the probability masses for each state in a discretized von Mises distribution +function _make_probabilities(corridor_length) + d = VonMises() # von Mises distribution with zero mean and unit concentration + min_ = minimum(d) # default to -π + max_ = maximum(d) # defaults to π step = (max_ - min_) / corridor_length probabilities = [] if corridor_length % 2 == 0 # offset indices by step / 2 when corridor_length is even @@ -44,8 +44,18 @@ function CircularMaze( push!(probabilities, m) end end + return probabilities +end + +function CircularMaze( + n_corridors::Integer, + corridor_length::Integer, + discount::Float64, +) + probabilities = _make_probabilities(corridor_length) center = div(corridor_length, 2) + 1 - pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount) + goals = [] # TODO: fix this + pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount, goals) return pomdp end @@ -59,13 +69,15 @@ const RIGHT = 1 const SENSE_CORRIDOR = 2 const DECLARE_GOAL = 3 +const ACTIONS = [ + LEFT, + RIGHT, + SENSE_CORRIDOR, + DECLARE_GOAL +] + function POMDPs.actions(::CircularMaze) - A = [ - LEFT, - RIGHT, - SENSE_CORRIDOR, - DECLARE_GOAL - ] + A = ACTIONS return A end @@ -86,8 +98,8 @@ function POMDPs.states(pomdp::CircularMaze) return space end -function POMDPs.stateindex(::CircularMaze, s::CircularMazeState) - index = (s.corridor - 1) + s.x +function POMDPs.stateindex(pomdp::CircularMaze, s::CircularMazeState) + index = (s.corridor - 1) * pomdp.corridor_length + s.x return index end @@ -96,41 +108,78 @@ function POMDPs.stateindex(pomdp::CircularMaze, ::TerminalState) return index end +# the initial state distribution is a von Mises distributions each over corridor with a mean at the center function POMDPs.initialstate(pomdp::CircularMaze) - # TODO - return nothing -end - -function _make_sparse_cat(pomdp::CircularMaze, x::Integer) + probabilities = repeat(pomdp.probabilities, pomdp.n_corridors) + probabilities /= pomdp.n_corridors # normalize values to sum to 1 values = states(pomdp) - shifts = pomdp.center - s.x - probabilities = circshift(pomdp.probabilities, shifts) d = SparseCat(values, probabilities) return d end -function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Integer) +function _center_probabilities(pomdp::CircularMaze, x::Integer) + shifts = pomdp.center - x + centered_probabilities = circshift(pomdp.probabilities, shifts) + return centered_probabilities +end + +# return a discretized von Mises distribution with a center at x +# function _make_sparse_cat_over_single_corridor(pomdp::CircularMaze, x::Integer) +# values = states(pomdp) +# shifts = pomdp.center - x +# probabilities = circshift(pomdp.probabilities, shifts) +# d = SparseCat(values, probabilities) # NOTE: we use SparseCat b/c it makes it easy to define a distribution sample as a CircularMazeState +# return d +# end + +# # return a distribution over all possible states +# function _make_sparse_cat_over_all_corridors(pomdp::CircularMaze, x::Integer) +# values = states(pomdp) +# shifts = pomdp.center - x +# probabilities = circshift(pomdp.probabilities, shifts) +# probabilities = repeat(probabilities, pomdp.n_corridors) +# probabilities /= pomdp.n_corridors # normalize values to sum to 1 +# d = SparseCat(values, probabilities) +# return d +# end + +# observations identify the current state modulo 100 with a mean equal to the true state s.x (modulo 100) +function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Integer, sp::CircularMazeState) if a == SENSE_CORRIDOR - obs = s.corridor + obs = Deterministic(s.corridor) else - obs = _make_sparse_cat(pomdp, s.x) + # corridor = s.corridor + # corridor_states = [] + # for x ∈ 1:pomdp.corridor_length + # s_ = CircularMazeState(corridor, x) + # push!(states, s_) + # end + values = 1:pomdp.corridor_length + probabilities = _center_probabilities(pomdp, s.x) + d = SparseCat(values, probabilities) + obs = d end return obs end function POMDPs.observations(pomdp::CircularMaze) + # NOTE: In JuliaPOMDPs, an observation space is NOT the set of possible distributions, but rather union of the support of all possible observations corridors = 1:pomdp.n_corridors # from SENSE_CORRIDOR - distribution_observations = permutations(pomdp.probabilities) - space = chain(corridors, distribution_observations) # generator + perms = permutations(pomdp.probabilities) + space = chain(corridors, perms) # generator return space end # TODO: maybe implement POMDPs.obsindex + function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer) + @assert a in actions(pomdp) "Unrecognized action $a" if a == DECLARE_GOAL + # env resets when goal is declared regardless of whether agent is actually at the goal d = Deterministic(terminalstate) else + # move left/right with some von Mises noise if a == LEFT x = s.x - 1 if x < 1 @@ -141,8 +190,24 @@ function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer else x = s.x end - d = _make_sparse_cat(pomdp, x) + corridor = s.corridor + corridor_states = [] + for x_ ∈ 1:pomdp.corridor_length + s_ = CircularMazeState(corridor, x_) + push!(states, s_) + end + probabilities = _center_probabilities(pomdp, x) + d = SparseCat(corridor_states, probabilities) end return d end - \ No newline at end of file + +function POMDPs.reward(pomdp::CircularMaze, s::Union{CircularMaze, TerminalState}, a::Integer) + @assert a in actions(pomdp) "Unrecognized action $a" + if s ∈ pomdp.goals && a == DECLARE_GOAL + r = pomdp.r_findgoal + else + r = 0 + end + return r +end diff --git a/test/circular_tests.jl b/test/circular_tests.jl new file mode 100644 index 0000000..cb48187 --- /dev/null +++ b/test/circular_tests.jl @@ -0,0 +1,42 @@ +# using POMDPModels +# using POMDPTools +# using Test + +# let +# pomdp = TigerPOMDP() + +# pomdp2 = TabularPOMDP(T, R, O, 0.95) + +# policy = RandomPolicy(pomdp, rng=MersenneTwister(2)) +# sim = RolloutSimulator(rng=MersenneTwister(3), max_steps=100) + +# simulate(sim, pomdp1, policy, updater(policy), initialstate(pomdp1)) + +# o = last(observations(pomdp1)) +# @test o == 1 +# # test vec +# ov = convert_o(Array{Float64}, true, pomdp1) +# @test ov == [1.] +# o = convert_o(Bool, ov, pomdp1) +# @test o == true + +# @test has_consistent_distributions(pomdp) + +# @test reward(pomdp, TIGER_LEFT, TIGER_OPEN_LEFT) == pomdp1.r_findtiger +# @test reward(pomdp1, TIGER_LEFT, TIGER_OPEN_RIGHT) == pomdp1.r_escapetiger +# @test reward(pomdp1, TIGER_RIGHT, TIGER_OPEN_RIGHT) == pomdp1.r_findtiger +# @test reward(pomdp1, TIGER_RIGHT, TIGER_OPEN_LEFT) == pomdp1.r_escapetiger +# @test reward(pomdp1, TIGER_RIGHT, TIGER_LISTEN) == pomdp1.r_listen + +# for s in states(pomdp1) +# @test pdf(transition(pomdp1, s, TIGER_LISTEN), s) == 1.0 +# @test pdf(transition(pomdp1, s, TIGER_OPEN_LEFT), s) == 0.5 +# @test pdf(transition(pomdp1, s, TIGER_OPEN_RIGHT), s) == 0.5 +# end + +# for s in states(pomdp1) +# @test pdf(observation(pomdp1, TIGER_LISTEN, s), s) == pomdp1.p_listen_correctly +# @test pdf(observation(pomdp1, TIGER_OPEN_LEFT, s), s) == 0.5 +# @test pdf(observation(pomdp1, TIGER_OPEN_RIGHT, s), s) == 0.5 +# end +# end \ No newline at end of file