diff --git a/lib/POMDPTools/src/CommonRLIntegration/CommonRLIntegration.jl b/lib/POMDPTools/src/CommonRLIntegration/CommonRLIntegration.jl index 70c77b2b..b2fd39ee 100644 --- a/lib/POMDPTools/src/CommonRLIntegration/CommonRLIntegration.jl +++ b/lib/POMDPTools/src/CommonRLIntegration/CommonRLIntegration.jl @@ -10,14 +10,15 @@ using Tricks: static_hasmethod export MDPCommonRLEnv, - POMDPCommonRLEnv -include("to_env.jl") + POMDPCommonRLEnv, + POMDPsCommonRLEnv +include("to_env.jl") export RLEnvMDP, RLEnvPOMDP, OpaqueRLEnvMDP, OpaqueRLEnvPOMDP -include("from_env.jl") +include("from_env.jl") end diff --git a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl index a5ea5045..61b6d737 100644 --- a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl +++ b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl @@ -97,5 +97,14 @@ end Base.convert(::Type{RL.AbstractEnv}, m::POMDP) = POMDPCommonRLEnv(m) Base.convert(::Type{RL.AbstractEnv}, m::MDP) = MDPCommonRLEnv(m) +POMDPsCommonRLEnv(m::POMDP, s) = POMDPCommonRLEnv(m, s) +POMDPsCommonRLEnv(m::MDP, s) = MDPCommonRLEnv(m, s) + Base.convert(::Type{MDP}, env::MDPCommonRLEnv) = env.m Base.convert(::Type{POMDP}, env::POMDPCommonRLEnv) = env.m + +POMDPs.convert_s(::Type{Any}, s::S, problem::Union{MDP{S},POMDP{S}}) where {S} = s +POMDPs.convert_s(::Type{S}, s, problem::Union{MDP{S},POMDP{S}}) where {S} = convert(S, s) + +POMDPs.convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o +POMDPs.convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o)