From 59d34ae2ee0c9e406ed37f59fcfc7a5867c916b5 Mon Sep 17 00:00:00 2001
From: Zachary Sunberg
Date: Thu, 6 Jun 2019 09:22:48 -0700
Subject: [PATCH] implemented initial distributions and terminal states (#2)

---
 src/discrete_explicit.jl  | 66 +++++++++++++++++++++++++--------------
 test/discrete_explicit.jl |  9 ++++++
 2 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/src/discrete_explicit.jl b/src/discrete_explicit.jl
index 6b2d93a..c0ba57a 100644
--- a/src/discrete_explicit.jl
+++ b/src/discrete_explicit.jl
@@ -1,4 +1,4 @@
-struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
+struct DiscreteExplicitPOMDP{S,A,O,OF,RF,D} <: POMDP{S,A,O}
     s::Vector{S}
     a::Vector{A}
     o::Vector{O}
@@ -10,9 +10,11 @@ struct DiscreteExplicitPOMDP{S,A,O,OF,RF} <: POMDP{S,A,O}
     amap::Dict{A,Int}
     omap::Dict{O,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end
 
-struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
+struct DiscreteExplicitMDP{S,A,RF,D} <: MDP{S,A}
     s::Vector{S}
     a::Vector{A}
     tds::Dict{Tuple{S,A}, SparseCat{Vector{S}, Vector{Float64}}}
@@ -20,6 +22,8 @@ struct DiscreteExplicitMDP{S,A,RF} <: MDP{S,A}
     smap::Dict{S,Int}
     amap::Dict{A,Int}
     discount::Float64
+    initial::D
+    terminals::Set{S}
 end
 
 const DEP = DiscreteExplicitPOMDP
@@ -42,38 +46,51 @@ POMDPs.transition(m::DE, s, a) = m.tds[s,a]
 POMDPs.observation(m::DEP, a, sp) = m.ods[a,sp]
 POMDPs.reward(m::DE, s, a) = m.r(s, a)
-POMDPs.initialstate_distribution(m::DEP) = uniform_belief(m)
-# XXX hack
-POMDPs.initialstate_distribution(m::DiscreteExplicitMDP) = uniform_belief(FullyObservablePOMDP(m))
+POMDPs.initialstate_distribution(m::DE) = m.initial
+
+POMDPs.isterminal(m::DE,s) = s in m.terminals
+
+#=
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractArray,W<:AbstractArray} =
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractVector} = convert_to_vec(V, s, m.smap)
+POMDPs.convert_s(::Type{V}, s::W, m::DE) where {V<:AbstractVector} = convert_to_vec(V, s, m.smap)
+
+POMDPs.convert_s(::Type{V}, s, m::DE) where {V<:AbstractArray} = convert_to_vec(V, s, m.smap)
+POMDPs.convert_a(::Type{V}, a, m::DE) where {V<:AbstractArray} = convert_to_vec(V, a, m.amap)
+POMDPs.convert_o(::Type{V}, o, m::DEP) where {V<:AbstractArray} = convert_to_vec(V, o, m.omap)
+POMDPs.convert_s(::Type{}
+=#
+
+#=
+convert_to_vec(V, x, map) = convert(V, [map[x]])
+convert_from_vec(T, v, space) = convert(T, space[convert(Integer, first(v))])
+=#
 
 POMDPModelTools.ordered_states(m::DE) = m.s
 POMDPModelTools.ordered_actions(m::DE) = m.a
 POMDPModelTools.ordered_observations(m::DEP) = m.o
 
-# TODO reward(m, s, a)
-# TODO support O(s, a, sp, o)
-# TODO initial state distribution
-# TODO convert_s, etc, dimensions
-# TODO better errors if T or Z return something unexpected
-
 """
-    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ)
+    DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,[b₀],[terminal=Set()])
 
 Create a POMDP defined by the tuple (S,A,O,T,Z,R,γ).
 
 # Arguments
+## Required
 - `S`,`A`,`O`: State, action, and observation spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `Z::Function`: Observation probability distribution function; ``Z(a, s', o)`` is the probability of receiving observation ``o`` when state ``s'`` is reached after action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.
 
-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `b₀=Uniform(S)`: Initial belief/state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
+function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount, b0=Uniform(s); terminal=Set())
     ss = vec(collect(s))
     as = vec(collect(a))
     os = vec(collect(o))
@@ -107,7 +124,7 @@ function DiscreteExplicitPOMDP(s, a, o, t, z, r, discount)
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
         Dict(os[i]=>i for i in 1:length(os)),
-        discount
+        discount, b0, terminal
     )
 
     probability_check(m)
@@ -116,22 +133,25 @@
 """
-    DiscreteExplicitMDP(S,A,T,R,γ)
+    DiscreteExplicitMDP(S,A,T,R,γ,[p₀],[terminal=Set()])
 
 Create an MDP defined by the tuple (S,A,T,R,γ).
 
 # Arguments
+## Required
 - `S`,`A`: State and action spaces (typically `Vector`s)
 - `T::Function`: Transition probability distribution function; ``T(s,a,s')`` is the probability of transitioning to state ``s'`` from state ``s`` after taking action ``a``.
 - `R::Function`: Reward function; ``R(s,a)`` is the reward for taking action ``a`` in state ``s``.
 - `γ::Float64`: Discount factor.
 
-# Notes
-- The default initial state distribution is uniform across all states. Changing this is not yet supported, but it can be overridden for simulations.
-- Terminal states are not yet supported, but absorbing states with zero reward can be used.
+## Optional
+- `p₀=Uniform(S)`: Initial state distribution (See `POMDPModelTools.Deterministic` and `POMDPModelTools.SparseCat` for other options).
+
+## Keyword
+- `terminal=Set()`: Set of terminal states. Once a terminal state is reached, no more actions can be taken or reward received.
 """
-function DiscreteExplicitMDP(s, a, t, r, discount)
+function DiscreteExplicitMDP(s, a, t, r, discount, p0=Uniform(s); terminal=Set())
     ss = vec(collect(s))
     as = vec(collect(a))
@@ -141,7 +161,7 @@ function DiscreteExplicitMDP(s, a, t, r, discount)
         ss, as, tds, r,
         Dict(ss[i]=>i for i in 1:length(ss)),
         Dict(as[i]=>i for i in 1:length(as)),
-        discount
+        discount, p0, terminal
     )
 
     trans_prob_consistency_check(m)
diff --git a/test/discrete_explicit.jl b/test/discrete_explicit.jl
index d52d7ae..5c6ed74 100644
--- a/test/discrete_explicit.jl
+++ b/test/discrete_explicit.jl
@@ -47,12 +47,18 @@ end
     println("Undiscounted reward was $rsum.")
     @test rsum == -10.0
+
+    dm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,Deterministic(:left))
+    @test initialstate(dm, Random.GLOBAL_RNG) == :left
+    tm = DiscreteExplicitPOMDP(S,A,O,T,Z,R,γ,terminal=Set(S))
+    @test isterminal(tm, initialstate(tm, Random.GLOBAL_RNG))
 end
 
 @testset "Discrete Explicit MDP" begin
     S = 1:5
     A = [-1, 1]
     γ = 0.95
+    p₀ = Deterministic(1)
 
     function T(s, a, sp)
         if sp == clamp(s+a,1,5)
@@ -73,6 +79,9 @@ end
 
     m = DiscreteExplicitMDP(S,A,T,R,γ)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀)
+    m = DiscreteExplicitMDP(S,A,T,R,γ,p₀,terminal=Set(5))
+    @test isterminal(m, 5)
 
     solver = FunctionSolver(x->1)
     policy = solve(solver, m)
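
Below is a minimal usage sketch of the new POMDP constructor arguments. It is not part of the patch: the two-state model (state, action, and observation spaces, T, Z, and R) is invented for illustration, and it assumes the file above belongs to QuickPOMDPs.jl so that the constructor is available via `using QuickPOMDPs`.

using POMDPs, POMDPModelTools, QuickPOMDPs

S = [:healthy, :faulty]   # hypothetical state space
A = [:repair, :wait]      # hypothetical action space
O = [:alarm, :quiet]      # hypothetical observation space

T(s, a, sp) = s == sp ? 0.9 : 0.1                           # sticky dynamics; sums to 1 over sp
Z(a, sp, o) = (o == :alarm) == (sp == :faulty) ? 0.8 : 0.2  # noisy alarm; sums to 1 over o
R(s, a) = s == :faulty ? -1.0 : 0.0
γ = 0.95

# New in this patch: an explicit initial distribution (positional argument)
# and a set of terminal states (keyword argument).
m = DiscreteExplicitPOMDP(S, A, O, T, Z, R, γ,
                          Deterministic(:healthy);   # b₀
                          terminal = Set([:faulty]))

initialstate_distribution(m)   # the Deterministic(:healthy) passed in
isterminal(m, :faulty)         # true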
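
The MDP constructor gains the same options. Continuing in the same session as the sketch above (again illustrative, not copied from the test file), a five-state chain with a deterministic initial state and a terminal right end could look like:

S = 1:5
A = [-1, 1]
T(s, a, sp) = sp == clamp(s + a, 1, 5) ? 1.0 : 0.0   # deterministic chain; sums to 1 over sp
R(s, a) = s == 5 ? 1.0 : 0.0
γ = 0.95

m = DiscreteExplicitMDP(S, A, T, R, γ, Deterministic(1); terminal = Set(5))

initialstate_distribution(m)   # Deterministic(1)
isterminal(m, 5)               # true

Both new arguments are optional, so existing code keeps working: `DiscreteExplicitMDP(S,A,T,R,γ)` still defaults to `Uniform(S)` and an empty terminal set.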