diff --git a/Project.toml b/Project.toml index 5449205..c6963d7 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "1.0.0" [deps] Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LocalApproximationValueIteration = "a40420fb-f401-52da-a663-f502e5b95060" LocalFunctionApproximation = "db97f5ab-fc25-52dd-a8f9-02a257c35074" diff --git a/src/CompressedBeliefMDPs.jl b/src/CompressedBeliefMDPs.jl index f5e4570..c88ee92 100644 --- a/src/CompressedBeliefMDPs.jl +++ b/src/CompressedBeliefMDPs.jl @@ -11,6 +11,8 @@ using Distributions using Bijections using NearestNeighbors using StaticArrays +using Combinatorics +using IterTools using LinearAlgebra using Random @@ -58,8 +60,8 @@ export include("solver.jl") export - CircularCorridorPOMDP, - CircularCorridorState + CircularMaze, + CircularMazeState include("envs/circular.jl") end # module CompressedBeliefMDPs diff --git a/src/envs/circular.jl b/src/envs/circular.jl index 7225d7b..efa4d10 100644 --- a/src/envs/circular.jl +++ b/src/envs/circular.jl @@ -1,38 +1,65 @@ -struct CircularCorridorState - row::Integer - index::Integer +struct CircularMazeState + corridor::Integer # corridor number + x::Integer # position in corridor end - -struct CircularCorridorPOMDP <: POMDP{Union{CircularCorridorState, TerminalState}, Integer, Integer} - corridor_length::Integer - num_corridors::Integer - distribution::VonMises - goals::AbstractArray{CircularCorridorState} - discount_factor::Float64 +struct CircularMaze <: POMDP{Union{CircularMazeState, TerminalState}, Integer, Integer} + n_corridors::Integer # number of corridors + corridor_length::Integer # corridor length + probabilities::AbstractArray # probability masses for creating von Mises distributions + center::Integer + discount::Float64 end +function _get_mass(d, x1, x2) + @assert x2 >= x1 + c1 = cdf(d, x1) + c2 = cdf(d, x2) + m = c2 - c1 + return m +end -function CircularCorridorPOMDP(; corridor_length::Integer=200, num_corridors::Integer=2, discount_factor::Float64=0.95) - goals = [] - goal_indices = rand(1:corridor_length, num_corridors) - for (row, index) in enumerate(goal_indices) - goal = CircularCorridorState(row, index) - push!(goals, goal) +function CircularMaze( + n_corridors::Integer, + corridor_length::Integer, + discount::Float64, +) + d = VonMises() + min_ = minimum(d) + max_ = maximum(d) + step = (max_ - min_) / corridor_length + probabilities = [] + if corridor_length % 2 == 0 # offset indices by step / 2 when corridor_length is even + for x1 in (min_ + step / 2):step:(max_ - 1.5 * step) + m = _get_mass(d, x1, x1 + step) + push!(probabilities, m) + end + m1 = _get_mass(d, max_ - step / 2, max_) + m2 = _get_mass(d, min_, min_ + step / 2) + m = m1 + m2 + push!(probabilities, m) + else + for x1 in min_:step:(max_ - step) + m = _get_mass(d, x1, x1 + step) + push!(probabilities, m) + end end - distribution = VonMises() - pomdp = CircularCorridorPOMDP(corridor_length, num_corridors, distribution, goals, discount_factor) + center = div(corridor_length, 2) + 1 + pomdp = CircularMaze(n_corridors, corridor_length, probabilities, center, discount) return pomdp end +function CircularMaze() + pomdp = CircularMaze(2, 200, 0.99) + return pomdp +end const LEFT = 0 const RIGHT = 1 const SENSE_CORRIDOR = 2 const DECLARE_GOAL = 3 - -function POMDPs.actions(::CircularCorridorPOMDP) +function POMDPs.actions(::CircularMaze) A = [ LEFT, RIGHT, @@ -42,111 +69,80 @@ function POMDPs.actions(::CircularCorridorPOMDP) return A end - -function POMDPs.actionindex(::CircularCorridorPOMDP, a::Integer) - return a +function POMDPs.actionindex(::CircularMaze, a::Integer) + index = a + return index end - -function POMDPs.initialstate(p::CircularCorridorPOMDP) - num_states = p.num_corridors * p.corridor_length - belief = DiscreteBelief(num_states) - return belief -end - - -function POMDPs.states(p::CircularCorridorPOMDP) - # TODO: need to include terminal? - space = [] - for row in 1:p.num_corridors - for index in 1:p.corridor_length - state = CircularCorridorState(row, index) +function POMDPs.states(pomdp::CircularMaze) + space = statetype(pomdp)[] + for i ∈ 1:pomdp.n_corridors + for j ∈ 1:pomdp.corridor_length + state = CircularMazeState(i, j) push!(space, state) end end + push!(space, terminalstate) return space end - -function POMDPs.stateindex(p::CircularCorridorPOMDP, s::CircularCorridorState) - i = p.corridor_length * (s.row - 1) + s.index - return i +function POMDPs.stateindex(::CircularMaze, s::CircularMazeState) + index = (s.corridor - 1) + s.x + return index end - -function POMDPs.stateindex(p::CircularCorridorPOMDP, ::TerminalState) - i = p.corridor_length * p.num_corridors + 1 - return i +function POMDPs.stateindex(pomdp::CircularMaze, ::TerminalState) + index = pomdp.n_corridors * pomdp.corridor_length + 1 + return index end - -function POMDPs.observations(p::CircularCorridorPOMDP) - space = states(p) - return space +function POMDPs.initialstate(pomdp::CircularMaze) + # TODO + return nothing end - -function _sample_distribution(p::CircularCorridorPOMDP, rng) - sample = rand(rng, p.distribution) - min_ = minimum(p.distribution) - max_ = maximum(p.distribution) - step = (max_ - min_) / p.corridor_length - bins = collect(min_:step:max_) - i = searchsortedfirst(bins, sample) # TODO: replace this w/ NN search? # FIXME: searchsortedfirst doesn't work as needed - return i +function _make_sparse_cat(pomdp::CircularMaze, x::Integer) + values = states(pomdp) + shifts = pomdp.center - s.x + probabilities = circshift(pomdp.probabilities, shifts) + d = SparseCat(values, probabilities) + return d end - -function POMDPs.observation(p::CircularCorridorPOMDP, a::Integer, sp::CircularCorridorState) - # TODO: redo this - ImplicitDistribution() do rng - if a == SENSE_CORRIDOR - obs = sp.row - else - # TODO: how to represent n-modal distributions?? - μ = sp.index - sample = _sample_distribution(p, rng) - index = (μ + sample) % p.corridor_length - obs = index - end - return obs +function POMDPs.observation(pomdp::CircularMaze, s::CircularMazeState, a::Integer) + if a == SENSE_CORRIDOR + obs = s.corridor + else + obs = _make_sparse_cat(pomdp, s.x) end + return obs end - -function POMDPs.transition(p::CircularCorridorPOMDP, s::CircularCorridorState, a) - ImplicitDistribution() do rng - if a == DECLARE_GOAL - sp = TerminalState() - else - if a == LEFT - # FIXME: abs is wrong - μ = abs(s.index - 1) % p.corridor_length - elseif a == RIGHT - μ = (s.index + 1) % p.corridor_length - else - μ = s.index - end - sample = _sample_distribution(p, rng) - index = (μ + sample) % p.corridor_length - row = s.row - sp = CircularCorridorState(row, index) - end - return sp - end +function POMDPs.observations(pomdp::CircularMaze) + corridors = 1:pomdp.n_corridors # from SENSE_CORRIDOR + distribution_observations = permutations(pomdp.probabilities) + space = chain(corridors, distribution_observations) # generator + return space end +# TODO: maybe implement POMDPs.obsindex -function POMDPs.reward(p::CircularCorridorPOMDP, s::CircularCorridorState, a) - if a == DECLARE_GOAL && s in p.goals - r = 1 +function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer) + if a == DECLARE_GOAL + d = Deterministic(terminalstate) else - r = 0 + if a == LEFT + x = s.x - 1 + if x < 1 + x = pomdp.corridor_length + end + elseif a == RIGHT + x = (s.x + 1) % pomdp.corridor_length + else + x = s.x + end + d = _make_sparse_cat(pomdp, x) end - return r + return d end - - -function POMDPs.discount(p::CircularCorridorPOMDP) - return p.discount_factor -end \ No newline at end of file + \ No newline at end of file