Skip to content

Commit

Permalink
Change state representation in POMDPCommonRLEnv
Browse files Browse the repository at this point in the history
Add RLS parameter for optional state type conversion. Return only
converted state instead of state-observation tuple in `RL.state()`
  • Loading branch information
johannes-fischer committed Aug 23, 2023
1 parent ec89812 commit ec1d1cf
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions lib/POMDPTools/src/CommonRLIntegration/to_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end
"""
MDPCommonRLEnv(m, [s])
MDPCommonRLEnv{RLO}(m, [s])
Create a CommonRLInterface environment from MDP m; optionally specify the state 's'.
The `RLO` parameter can be used to specify a type to convert the observation to. By default, this is `AbstractArray`. Use `Any` to disable conversion.
Expand Down Expand Up @@ -47,7 +47,7 @@ function RL.setstate!(env::MDPCommonRLEnv{<:Any, <:Any, S}, s) where S
return nothing
end

mutable struct POMDPCommonRLEnv{RLO, M<:POMDP, S, O} <: AbstractPOMDPsCommonRLEnv
mutable struct POMDPCommonRLEnv{RLO, RLS, M<:POMDP, S, O} <: AbstractPOMDPsCommonRLEnv
m::M
s::S
o::O
Expand All @@ -56,13 +56,14 @@ end
"""
POMDPCommonRLEnv(m, [s], [o])
POMDPCommonRLEnv{RLO}(m, [s], [o])
POMDPCommonRLEnv{RLO,RLS}(m, [s], [o])
Create a CommonRLInterface environment from POMDP m; optionally specify the state 's' and observation 'o'.
The `RLO` parameter can be used to specify a type to convert the observation to. By default, this is `AbstractArray`. Use `Any` to disable conversion.
The `RLO` and `RLS` parameters can be used to specify the types that the observation and the state, respectively, are converted to. By default, both are `AbstractArray`. Use `Any` to disable conversion.
"""
POMDPCommonRLEnv{RLO}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO} = POMDPCommonRLEnv{RLO, typeof(m), statetype(m), obstype(m)}(m, s, o)
POMDPCommonRLEnv(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) = POMDPCommonRLEnv{AbstractArray}(m, s, o)
# Outer constructor with both conversion targets given explicitly. The remaining
# type parameters are filled in from the model via `typeof`, `statetype` and `obstype`.
function POMDPCommonRLEnv{RLO,RLS}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO,RLS}
    return POMDPCommonRLEnv{RLO,RLS,typeof(m),statetype(m),obstype(m)}(m, s, o)
end

# Outer constructor with only `RLO` given; `RLS` is fixed to `AbstractArray`.
function POMDPCommonRLEnv{RLO}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO}
    return POMDPCommonRLEnv{RLO,AbstractArray,typeof(m),statetype(m),obstype(m)}(m, s, o)
end

# Outer constructor without type parameters; both `RLO` and `RLS` are `AbstractArray`.
function POMDPCommonRLEnv(m, s=rand(initialstate(m)), o=rand(initialobs(m, s)))
    return POMDPCommonRLEnv{AbstractArray,AbstractArray}(m, s, o)
end

function RL.reset!(env::POMDPCommonRLEnv)
env.s = rand(initialstate(env.m))
Expand All @@ -79,17 +80,17 @@ end

# Return the stored observation `env.o`, passed through `convert_o` with the
# environment's `RLO` type parameter and the model `env.m`.
function RL.observe(env::POMDPCommonRLEnv{RLO}) where {RLO}
    return convert_o(RLO, env.o, env.m)
end

RL.clone(env::POMDPCommonRLEnv{RLO}) where {RLO} = POMDPCommonRLEnv{RLO}(env.m, env.s, env.o)
# Build a new environment with the same `RLO`/`RLS` parameters, handing the current
# model, state and observation to the constructor as-is.
function RL.clone(env::POMDPCommonRLEnv{RLO,RLS}) where {RLO,RLS}
    return POMDPCommonRLEnv{RLO,RLS}(env.m, env.s, env.o)
end
# Render the model, passing the current state under the key `sp` and the current
# observation under the key `o`.
function RL.render(env::POMDPCommonRLEnv)
    step_info = (sp=env.s, o=env.o)
    return render(env.m, step_info)
end
RL.state(env::POMDPCommonRLEnv) = (env.s, env.o)
# Return the stored state `env.s`, passed through `convert_s` with the environment's
# `RLS` type parameter and the model `env.m`. The observation is not part of the result.
function RL.state(env::POMDPCommonRLEnv{RLO,RLS}) where {RLO,RLS}
    return convert_s(RLS, env.s, env.m)
end
# Return `actions` of the model evaluated at the environment's current state.
function RL.valid_actions(env::POMDPCommonRLEnv)
    return actions(env.m, env.s)
end

# Return a generator that applies `convert_o` (with the `RLO` type parameter) to each
# element of the model's observation space.
# should really be some kind of lazy map that handles uncountably infinite spaces
function RL.observations(env::POMDPCommonRLEnv{RLO}) where {RLO}
    raw_observations = observations(env.m)
    return (convert_o(RLO, obs, env.m) for obs in raw_observations)
end
RL.provided(::typeof(RL.observations), ::Type{<:Tuple{POMDPCommonRLEnv{<:Any, M, <:Any, <:Any}}}) where {M} = static_hasmethod(observations, Tuple{<:M})
# `RL.observations` is reported as provided for an environment exactly when
# `static_hasmethod` finds an `observations` method for the wrapped model type `M`
# (the third type parameter of `POMDPCommonRLEnv`).
function RL.provided(
    ::typeof(RL.observations),
    ::Type{<:Tuple{POMDPCommonRLEnv{<:Any,<:Any,M,<:Any,<:Any}}},
) where {M}
    return static_hasmethod(observations, Tuple{<:M})
end

function RL.setstate!(env::POMDPCommonRLEnv, so)
env.s = first(so)
env.o = last(so)
# Overwrite the environment's state with `s`, passed through `convert_s` with the
# state type parameter `S`. The stored observation is replaced by a fresh sample
# from `initialobs` for the new state. Returns `nothing`.
RL.@provide function RL.setstate!(env::POMDPCommonRLEnv{<:Any,<:Any,<:Any,S}, s) where {S}
    model = env.m
    env.s = convert_s(S, s, model)
    # Read the state back from the field so the observation is sampled for the
    # value actually stored in `env.s`.
    env.o = rand(initialobs(model, env.s))
    return nothing
end

Expand Down

0 comments on commit ec1d1cf

Please sign in to comment.