implemented initial distributions and terminal states (#2)
zsunberg committed Jun 6, 2019
1 parent e0d2596 commit 59d34ae
Showing 2 changed files with 52 additions and 23 deletions.
66 changes: 43 additions & 23 deletions src/discrete_explicit.jl
@@ -1,4 +1,4 @@
struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
struct DiscreteExplicitPOMDP{S,A,O,OF,RF,D} <: POMDP{S,A,O}
s::Vector{S}
a::Vector{A}
o::Vector{O}
@@ -10,16 +10,20 @@ struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
amap::Dict{A,Int}
omap::Dict{O,Int}
discount::Float64
initial::D
terminals::Set{S}
end

struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
struct DiscreteExplicitMDP{S,A,RF,D} <: MDP{S,A}
s::Vector{S}
a::Vector{A}
tds::Dict{Tuple{S,A}, SparseCat{Vector{S}, Vector{Float64}}}
r::RF
smap::Dict{S,Int}
amap::Dict{A,Int}
discount::Float64
initial::D
terminals::Set{S}
end

const DEP = DiscreteExplicitPOMDP
@@ -42,38 +46,51 @@ POMDPs.transition(m::DE, s, a) = m.tds[s,a]
POMDPs.observation(m::DEP, a, sp) = m.ods[a,sp]
POMDPs.reward(m::DE, s, a) = m.r(s, a)

POMDPs.initialstate_distribution(m::DEP) = uniform_belief(m)
# XXX hack
POMDPs.initialstate_distribution(m::DiscreteExplicitMDP) = uniform_belief(FullyObservablePOMDP(m))
POMDPs.initialstate_distribution(m::DE) = m.initial

POMDPs.isterminal(m::DE,s) = s in m.terminals

#=
POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractArray,W<:AbstractArray} =
POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractVector} = convert_to_vec(V, s, m.smap)
POMDPs.convert_s(::Type{V}, s, m::DE) where {V<:AbstractArray} = convert_to_vec(V, s, m.smap)
POMDPs.convert_a(::Type{V}, a, m::DE) where {V<:AbstractArray} = convert_to_vec(V, a, m.amap)
POMDPs.convert_o(::Type{V}, o, m::DEP) where {V<:AbstractArray} = convert_to_vec(V, o, m.omap)
POMDPs.convert_s(::Type{}
=#

#=
convert_to_vec(V, x, map) = convert(V, [map[x]])
convert_from_vec(T, v, space) = convert(T, space[convert(Integer, first(v))])
=#

POMDPModelTools.ordered_states(m::DE) = m.s
POMDPModelTools.ordered_actions(m::DE) = m.a
POMDPModelTools.ordered_observations(m::DEP) = m.o

# TODO reward(m, s, a)
# TODO support O(s, a, sp, o)
# TODO initial state distribution
# TODO convert_s, etc, dimensions
# TODO better errors if T or Z return something unexpected

"""
DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ)
DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,[b₀],[terminal=Set()])
Create a POMDP defined by the tuple (S,A,O,T,Z,R,γ).
# Arguments
## Required
- `S`,`A`,`O`: State, action, and observation spaces (typically `Vector`s)
- `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
- `Z::Function`: Observation probability distribution function; ``Z(a, s', o)`` is the probability of receiving observation ``o`` when state ``s'`` is reached after action ``a``.
- `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
- `γ::Float64`: Discount factor.
# Notes
- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
- Terminal states are not yet supported, but absorbing states with zero reward can be used.
## Optional
- `b₀=Uniform(S)`: Initial belief/state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
## Keyword
- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
"""
function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount, b0=Uniform(s); terminal=Set())
ss = vec(collect(s))
as = vec(collect(a))
os = vec(collect(o))
@@ -107,7 +124,7 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
Dict(ss[i]=>i for i in 1:length(ss)),
Dict(as[i]=>i for i in 1:length(as)),
Dict(os[i]=>i for i in 1:length(os)),
discount
discount, b0, terminal
)

probability_check(m)
@@ -116,22 +133,25 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
end

"""
DiscreteExplicitMDP(S,A,T,R,γ)
DiscreteExplicitMDP(S,A,T,R,γ,[p₀])
Create an MDP defined by the tuple (S,A,T,R,γ).
# Arguments
## Required
- `S`,`A`: State and action spaces (typically `Vector`s)
- `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
- `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
- `γ::Float64`: Discount factor.
# Notes
- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
- Terminal states are not yet supported, but absorbing states with zero reward can be used.
## Optional
- `p₀=Uniform(S)`: Initial state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
## Keyword
- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
"""
function DiscreteExplicitMDP(s, a, t, r, discount)
function DiscreteExplicitMDP(s, a, t, r, discount, p0=Uniform(s); terminal=Set())
ss = vec(collect(s))
as = vec(collect(a))

@@ -141,7 +161,7 @@ function DiscreteExplicitMDP(s, a, t, r, discount)
ss, as, tds, r,
Dict(ss[i]=>i for i in 1:length(ss)),
Dict(as[i]=>i for i in 1:length(as)),
discount
discount, p0, terminal
)

trans_prob_consistency_check(m)
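
The docstring above documents the new optional initial-distribution argument and `terminal` keyword for `DiscreteExplicitPOMDP`. A rough usage sketch, not part of this diff: the toy states, actions, observations, and the `T`, `Z`, and `R` functions below are invented for illustration, and the package is assumed to load as `QuickPOMDPs`; only the constructor signature and the `initialstate_distribution`/`isterminal` behavior mirror this commit.

using POMDPs, QuickPOMDPs, POMDPModelTools

S = [:healthy, :sick, :dead]       # :dead will be a terminal state
A = [:treat, :wait]
O = [:ok, :not_ok]

# Hypothetical transition probabilities T(s, a, s'); each (s, a) pair sums to 1 over s'.
function T(s, a, sp)
    if s == :dead
        return sp == :dead ? 1.0 : 0.0                          # absorbing
    elseif a == :treat
        return sp == :healthy ? 0.9 : (sp == :sick ? 0.1 : 0.0)
    else                                                        # :wait
        return sp == :healthy ? 0.5 : (sp == :sick ? 0.4 : 0.1)
    end
end

# Hypothetical observation probabilities Z(a, s', o); each (a, s') pair sums to 1 over o.
Z(a, sp, o) = sp == :healthy ? (o == :ok ? 0.8 : 0.2) : (o == :ok ? 0.2 : 0.8)

R(s, a) = (s == :sick && a == :treat) ? -1.0 : 0.0              # illustrative reward
γ = 0.95

b₀ = SparseCat([:healthy, :sick], [0.6, 0.4])                   # custom initial belief

m = DiscreteExplicitPOMDP(S, A, O, T, Z, R, γ, b₀, terminal=Set([:dead]))

initialstate_distribution(m)   # returns b₀ instead of the old uniform belief
isterminal(m, :dead)           # true
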
9 changes: 9 additions & 0 deletions test/discrete_explicit.jl
@@ -47,12 +47,18 @@
end
println("Undiscounted reward was $rsum.")
@test rsum == -10.0

dm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,Deterministic(:left))
@test initialstate(dm, Random.GLOBAL_RNG) == :left
tm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,terminal=Set(S))
@test isterminal(tm, initialstate(tm, Random.GLOBAL_RNG))
end

@testset "Discrete Explicit MDP" begin
S = 1:5
A = [-1, 1]
γ = 0.95
p₀ = Deterministic(1)

function T(s, a, sp)
if sp == clamp(s+a,1,5)
@@ -73,6 +79,9 @@ end
end

m = DiscreteExplicitMDP(S,A,T,R,γ)
m = DiscreteExplicitMDP(S,A,T,R,γ,p₀)
m = DiscreteExplicitMDP(S,A,T,R,γ,p₀,terminal=Set(5))
@test isterminal(m, 5)

solver = FunctionSolver(x->1)
policy = solve(solver, m)
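
The MDP test above exercises the same pattern. A minimal end-to-end sketch of the MDP constructor with an initial state distribution and a terminal state; the `R` function, the always-right policy, and the `POMDPSimulators`/`POMDPPolicies` usage are illustrative assumptions, not taken from this diff.

using POMDPs, QuickPOMDPs, POMDPModelTools, POMDPSimulators, POMDPPolicies

S = 1:5
A = [-1, 1]
T(s, a, sp) = sp == clamp(s + a, 1, 5) ? 1.0 : 0.0    # deterministic chain walk
R(s, a) = s == 5 ? 1.0 : 0.0                          # hypothetical reward
γ = 0.95

m = DiscreteExplicitMDP(S, A, T, R, γ, Deterministic(1), terminal=Set(5))

# Rollouts start at state 1 (the Deterministic initial distribution) and stop
# automatically once the terminal state 5 is entered.
policy = FunctionPolicy(s -> 1)                       # always step right
r = simulate(RolloutSimulator(max_steps=100), m, policy)
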
