From ec1d1cf7ebcd90fa42312bcf320f125e716eaa33 Mon Sep 17 00:00:00 2001 From: Johannes Fischer Date: Wed, 23 Aug 2023 16:17:13 +0200 Subject: [PATCH] Change state representation in POMDPCommonRLEnv Add RLS parameter for optional state type conversion. Return only converted state instead of state-observation tuple in `RL.state()` --- .../src/CommonRLIntegration/to_env.jl | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl index c9c39292..a5ea5045 100644 --- a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl +++ b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl @@ -13,7 +13,7 @@ end """ MDPCommonRLEnv(m, [s]) MDPCommonRLEnv{RLO}(m, [s]) - + Create a CommonRLInterface environment from MDP m; optionally specify the state 's'. The `RLO` parameter can be used to specify a type to convert the observation to. By default, this is `AbstractArray`. Use `Any` to disable conversion. @@ -47,7 +47,7 @@ function RL.setstate!(env::MDPCommonRLEnv{<:Any, <:Any, S}, s) where S return nothing end -mutable struct POMDPCommonRLEnv{RLO, M<:POMDP, S, O} <: AbstractPOMDPsCommonRLEnv +mutable struct POMDPCommonRLEnv{RLO, RLS, M<:POMDP, S, O} <: AbstractPOMDPsCommonRLEnv m::M s::S o::O @@ -56,13 +56,14 @@ end """ POMDPCommonRLEnv(m, [s], [o]) POMDPCommonRLEnv{RLO}(m, [s], [o]) - + Create a CommonRLInterface environment from POMDP m; optionally specify the state 's' and observation 'o'. -The `RLO` parameter can be used to specify a type to convert the observation to. By default, this is `AbstractArray`. Use `Any` to disable conversion. +The `RLO` and `RLS` parameters can be used to specify types to convert the observation and state to. By default, both are `AbstractArray`. Use `Any` to disable conversion. 
""" -POMDPCommonRLEnv{RLO}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO} = POMDPCommonRLEnv{RLO, typeof(m), statetype(m), obstype(m)}(m, s, o) -POMDPCommonRLEnv(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) = POMDPCommonRLEnv{AbstractArray}(m, s, o) +POMDPCommonRLEnv{RLO,RLS}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO,RLS} = POMDPCommonRLEnv{RLO,RLS,typeof(m),statetype(m),obstype(m)}(m, s, o) +POMDPCommonRLEnv{RLO}(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) where {RLO} = POMDPCommonRLEnv{RLO,AbstractArray,typeof(m),statetype(m),obstype(m)}(m, s, o) +POMDPCommonRLEnv(m, s=rand(initialstate(m)), o=rand(initialobs(m, s))) = POMDPCommonRLEnv{AbstractArray,AbstractArray}(m, s, o) function RL.reset!(env::POMDPCommonRLEnv) env.s = rand(initialstate(env.m)) @@ -79,17 +80,17 @@ end RL.observe(env::POMDPCommonRLEnv{RLO}) where {RLO} = convert_o(RLO, env.o, env.m) -RL.clone(env::POMDPCommonRLEnv{RLO}) where {RLO} = POMDPCommonRLEnv{RLO}(env.m, env.s, env.o) +RL.clone(env::POMDPCommonRLEnv{RLO,RLS}) where {RLO,RLS} = POMDPCommonRLEnv{RLO,RLS}(env.m, env.s, env.o) RL.render(env::POMDPCommonRLEnv) = render(env.m, (sp=env.s, o=env.o)) -RL.state(env::POMDPCommonRLEnv) = (env.s, env.o) +RL.state(env::POMDPCommonRLEnv{RLO,RLS}) where {RLO,RLS} = convert_s(RLS, env.s, env.m) RL.valid_actions(env::POMDPCommonRLEnv) = actions(env.m, env.s) RL.observations(env::POMDPCommonRLEnv{RLO}) where {RLO} = (convert_o(RLO, o, env.m) for o in observations(env.m)) # should really be some kind of lazy map that handles uncountably infinite spaces -RL.provided(::typeof(RL.observations), ::Type{<:Tuple{POMDPCommonRLEnv{<:Any, M, <:Any, <:Any}}}) where {M} = static_hasmethod(observations, Tuple{<:M}) +RL.provided(::typeof(RL.observations), ::Type{<:Tuple{POMDPCommonRLEnv{<:Any,<:Any,M,<:Any,<:Any}}}) where {M} = static_hasmethod(observations, Tuple{<:M}) -function RL.setstate!(env::POMDPCommonRLEnv, so) - env.s = first(so) - env.o = last(so) 
+RL.@provide function RL.setstate!(env::POMDPCommonRLEnv{<:Any,<:Any,<:Any,S}, s) where {S} + env.s = convert_s(S, s, env.m) + env.o = rand(initialobs(env.m, env.s)) return nothing end