Skip to content

Commit

Permalink
updating corridor
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Jun 23, 2024
1 parent 3644928 commit 69684a2
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 33 deletions.
131 changes: 98 additions & 33 deletions src/envs/circular.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct CircularMazeState
corridor::Integer # corridor number
x::Integer # position in corridor
corridor::Integer # corridor number ∈ [1, ..., n_corridors]
x::Integer # position in corridor ∈ [1, ..., corridor_length]
end

struct CircularMaze <: POMDP{Union{CircularMazeState, TerminalState}, Integer, Integer}
Expand All @@ -9,6 +9,9 @@ struct CircularMaze <: POMDP{Union{CircularMazeState, TerminalState}, Integer, I
probabilities::AbstractArray # probability masses for creating von Mises distributions
center::Integer
discount::Float64

r_findgoal::Float64
goals::AbstractArray
end

function _get_mass(d, x1, x2)
Expand All @@ -19,14 +22,11 @@ function _get_mass(d, x1, x2)
return m
end

function CircularMaze(
n_corridors::Integer,
corridor_length::Integer,
discount::Float64,
)
d = VonMises()
min_ = minimum(d)
max_ = maximum(d)
# get the probability masses for each state in a discretized von Mises distribution
function _make_probabilities(corridor_length)
d = VonMises() # von Mises distribution with zero mean and unit concentration
min_ = minimum(d) # default to -π
max_ = maximum(d) # defaults to π
step = (max_ - min_) / corridor_length
probabilities = []
if corridor_length % 2 == 0 # offset indices by step / 2 when corridor_length is even
Expand All @@ -44,8 +44,18 @@ function CircularMaze(
push!(probabilities, m)
end
end
return probabilities
end

function CircularMaze(
n_corridors::Integer,
corridor_length::Integer,
discount::Float64,
)
probabilities = _make_probabilities(corridor_length)
center = div(corridor_length, 2) + 1
pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount)
goals = [] # TODO: fix this
pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount, goals)
return pomdp
end

Expand All @@ -59,13 +69,15 @@ const RIGHT = 1
const SENSE_CORRIDOR = 2
const DECLARE_GOAL = 3

const ACTIONS = [
LEFT,
RIGHT,
SENSE_CORRIDOR,
DECLARE_GOAL
]

function POMDPs.actions(::CircularMaze)
A = [
LEFT,
RIGHT,
SENSE_CORRIDOR,
DECLARE_GOAL
]
A = ACTIONS
return A
end

Expand All @@ -86,8 +98,8 @@ function POMDPs.states(pomdp::CircularMaze)
return space
end

function POMDPs.stateindex(::CircularMaze, s::CircularMazeState)
index = (s.corridor - 1) + s.x
function POMDPs.stateindex(pomdp::CircularMaze, s::CircularMazeState)
index = (s.corridor - 1) * pomdp.corridor_length + s.x
return index
end

Expand All @@ -96,41 +108,78 @@ function POMDPs.stateindex(pomdp::CircularMaze, ::TerminalState)
return index
end

# the initial state distribution is a von Mises distributions each over corridor with a mean at the center
function POMDPs.initialstate(pomdp::CircularMaze)
# TODO
return nothing
end

function _make_sparse_cat(pomdp::CircularMaze, x::Integer)
probabilities = repeat(pomdp.probabilities, pomdp.n_corridors)
probabilities /= pomdp.n_corridors # normalize values to sum to 1
values = states(pomdp)
shifts = pomdp.center - s.x
probabilities = circshift(pomdp.probabilities, shifts)
d = SparseCat(values, probabilities)
return d
end

function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Integer)
function _center_probabilities(pomdp::CircularMaze, x::Integer)
shifts = pomdp.center - x
centered_probabilities = circshift(pomdp.probabilities, shifts)
return centered_probabilities
end

# return a discretized von Mises distribution with a center at x
# function _make_sparse_cat_over_single_corridor(pomdp::CircularMaze, x::Integer)
# values = states(pomdp)
# shifts = pomdp.center - x
# probabilities = circshift(pomdp.probabilities, shifts)
# d = SparseCat(values, probabilities) # NOTE: we use SparseCat b/c it makes it easy to define a distribution sample as a CircularMazeState
# return d
# end

# # return a distribution over all possible states
# function _make_sparse_cat_over_all_corridors(pomdp::CircularMaze, x::Integer)
# values = states(pomdp)
# shifts = pomdp.center - x
# probabilities = circshift(pomdp.probabilities, shifts)
# probabilities = repeat(probabilities, pomdp.n_corridors)
# probabilities /= pomdp.n_corridors # normalize values to sum to 1
# d = SparseCat(values, probabilities)
# return d
# end

# observations identify the current state modulo 100 with a mean equal to the true state s.x (modulo 100)
function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Integer, sp::CircularMazeState)
if a == SENSE_CORRIDOR
obs = s.corridor
obs = Deterministic(s.corridor)
else
obs = _make_sparse_cat(pomdp, s.x)
# corridor = s.corridor
# corridor_states = []
# for x ∈ 1:pomdp.corridor_length
# s_ = CircularMazeState(corridor, x)
# push!(states, s_)
# end
values = 1:pomdp.corridor_length
probabilities = _center_probabilities(pomdp, s.x)
d = SparseCat(values, probabilities)
obs = d
end
return obs
end

function POMDPs.observations(pomdp::CircularMaze)
# NOTE: In JuliaPOMDPs, an observation space is NOT the set of possible distributions, but rather union of the support of all possible observations
corridors = 1:pomdp.n_corridors # from SENSE_CORRIDOR
distribution_observations = permutations(pomdp.probabilities)
space = chain(corridors, distribution_observations) # generator
perms = permutations(pomdp.probabilities)
space = chain(corridors, perms) # generator
return space
end

# TODO: maybe implement POMDPs.obsindex


function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer)
@assert a in actions(pomdp) "Unrecognized action $a"
if a == DECLARE_GOAL
# env resets when goal is declared regardless of whether agent is actually at the goal
d = Deterministic(terminalstate)
else
# move left/right with some von Mises noise
if a == LEFT
x = s.x - 1
if x < 1
Expand All @@ -141,8 +190,24 @@ function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer
else
x = s.x
end
d = _make_sparse_cat(pomdp, x)
corridor = s.corridor
corridor_states = []
for x_ 1:pomdp.corridor_length
s_ = CircularMazeState(corridor, x_)
push!(states, s_)
end
probabilities = _center_probabilities(pomdp, x)
d = SparseCat(corridor_states, probabilities)
end
return d
end


function POMDPs.reward(pomdp::CircularMaze, s::Union{CircularMaze, TerminalState}, a::Integer)
@assert a in actions(pomdp) "Unrecognized action $a"
if s pomdp.goals && a == DECLARE_GOAL
r = pomdp.r_findgoal
else
r = 0
end
return r
end
42 changes: 42 additions & 0 deletions test/circular_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# using POMDPModels
# using POMDPTools
# using Test

# let
# pomdp = TigerPOMDP()

# pomdp2 = TabularPOMDP(T, R, O, 0.95)

# policy = RandomPolicy(pomdp, rng=MersenneTwister(2))
# sim = RolloutSimulator(rng=MersenneTwister(3), max_steps=100)

# simulate(sim, pomdp1, policy, updater(policy), initialstate(pomdp1))

# o = last(observations(pomdp1))
# @test o == 1
# # test vec
# ov = convert_o(Array{Float64}, true, pomdp1)
# @test ov == [1.]
# o = convert_o(Bool, ov, pomdp1)
# @test o == true

# @test has_consistent_distributions(pomdp)

# @test reward(pomdp, TIGER_LEFT, TIGER_OPEN_LEFT) == pomdp1.r_findtiger
# @test reward(pomdp1, TIGER_LEFT, TIGER_OPEN_RIGHT) == pomdp1.r_escapetiger
# @test reward(pomdp1, TIGER_RIGHT, TIGER_OPEN_RIGHT) == pomdp1.r_findtiger
# @test reward(pomdp1, TIGER_RIGHT, TIGER_OPEN_LEFT) == pomdp1.r_escapetiger
# @test reward(pomdp1, TIGER_RIGHT, TIGER_LISTEN) == pomdp1.r_listen

# for s in states(pomdp1)
# @test pdf(transition(pomdp1, s, TIGER_LISTEN), s) == 1.0
# @test pdf(transition(pomdp1, s, TIGER_OPEN_LEFT), s) == 0.5
# @test pdf(transition(pomdp1, s, TIGER_OPEN_RIGHT), s) == 0.5
# end

# for s in states(pomdp1)
# @test pdf(observation(pomdp1, TIGER_LISTEN, s), s) == pomdp1.p_listen_correctly
# @test pdf(observation(pomdp1, TIGER_OPEN_LEFT, s), s) == 0.5
# @test pdf(observation(pomdp1, TIGER_OPEN_RIGHT, s), s) == 0.5
# end
# end

0 comments on commit 69684a2

Please sign in to comment.