Skip to content

Commit

Permalink
attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed May 13, 2024
1 parent 2b39b73 commit 0e51536
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
6 changes: 5 additions & 1 deletion src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ include("compressors/manifold_compressors.jl")
include("compressors/autoencoders.jl")
include("compressors/vae.jl")


export
Sampler,
sample,
Expand All @@ -58,4 +57,9 @@ export
solve
include("solver.jl")

export
CircularCorridorPOMDP,
CircularCorridorState
include("envs/circular.jl")

end # module CompressedBeliefMDPs
42 changes: 33 additions & 9 deletions src/envs/circular_corridors.jl → src/envs/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ function CircularCorridorPOMDP(; corridor_length::Integer=200, num_corridors::In
end


# actions
const LEFT = 0
const RIGHT = 1
const SENSE_CORRIDOR = 2
Expand All @@ -44,32 +43,50 @@ function POMDPs.actions(::CircularCorridorPOMDP)
end


function POMDPs.initial_belief(p::CircularCorridorPOMDP)
function POMDPs.actionindex(::CircularCorridorPOMDP, a::Integer)
return a
end


function POMDPs.initialstate(p::CircularCorridorPOMDP)
num_states = p.num_corridors * p.corridor_length
belief = DiscreteBelief(num_states)
return belief
end


function POMDPs.states(p::CircularCorridorPOMDP)
S = []
# TODO: need to include terminal?
space = []
for row in 1:p.num_corridors
for index in 1:p.corridor_length
state = CircularCorridorState(row, index)
push!(S, state)
push!(space, state)
end
end
return S
return space
end


function POMDPs.stateindex(p::CircularCorridorPOMDP, s::CircularCorridorState)
i = p.corridor_length * s.row + s.index
i = p.corridor_length * (s.row - 1) + s.index
return i
end


function POMDPs.stateindex(p::CircularCorridorPOMDP, ::TerminalState)
i = p.corridor_length * p.num_corridors + 1
return i
end


function _sample_distribution(p::CircularCorridorsPOMDP, rng)
function POMDPs.observations(p::CircularCorridorPOMDP)
space = states(p)
return space
end


function _sample_distribution(p::CircularCorridorPOMDP, rng)
sample = rand(rng, p.distribution)
min_ = minimum(p.distribution)
max_ = maximum(p.distribution)
Expand All @@ -80,7 +97,8 @@ function _sample_distribution(p::CircularCorridorsPOMDP, rng)
end


function POMDPs.observation(p::CircularCorridorsPOMDP, a::Integer, sp::CircularCorridorsState)
function POMDPs.observation(p::CircularCorridorPOMDP, a::Integer, sp::CircularCorridorState)
# TODO: redo this
ImplicitDistribution() do rng
if a == SENSE_CORRIDOR
obs = sp.row
Expand All @@ -96,12 +114,13 @@ function POMDPs.observation(p::CircularCorridorsPOMDP, a::Integer, sp::CircularC
end


function POMDPs.transition(p::CircularCorridorsPOMDP, s::CircularCorridorState, a)
function POMDPs.transition(p::CircularCorridorPOMDP, s::CircularCorridorState, a)
ImplicitDistribution() do rng
if a == DECLARE_GOAL
sp = TerminalState()
else
if a == LEFT
# FIXME: abs is wrong
μ = abs(s.index - 1) % p.corridor_length
elseif a == RIGHT
μ = (s.index + 1) % p.corridor_length
Expand All @@ -126,3 +145,8 @@ function POMDPs.reward(p::CircularCorridorPOMDP, s::CircularCorridorState, a)
end
return r
end


function POMDPs.discount(p::CircularCorridorPOMDP)
return p.discount_factor
end

0 comments on commit 0e51536

Please sign in to comment.