Skip to content

Commit

Permalink
Add convenience conversions in CommonRLEnv
Browse files Browse the repository at this point in the history
  • Loading branch information
johannes-fischer committed Aug 23, 2023
1 parent ec1d1cf commit b58909c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
7 changes: 4 additions & 3 deletions lib/POMDPTools/src/CommonRLIntegration/CommonRLIntegration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ using Tricks: static_hasmethod

export
MDPCommonRLEnv,
POMDPCommonRLEnv
include("to_env.jl")
POMDPCommonRLEnv,
POMDPsCommonRLEnv
include("to_env.jl")

export
RLEnvMDP,
RLEnvPOMDP,
OpaqueRLEnvMDP,
OpaqueRLEnvPOMDP
include("from_env.jl")
include("from_env.jl")

end
9 changes: 9 additions & 0 deletions lib/POMDPTools/src/CommonRLIntegration/to_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,14 @@ end
Base.convert(::Type{RL.AbstractEnv}, m::POMDP) = POMDPCommonRLEnv(m)
Base.convert(::Type{RL.AbstractEnv}, m::MDP) = MDPCommonRLEnv(m)

POMDPsCommonRLEnv(m::POMDP, s) = POMDPCommonRLEnv(m, s)
POMDPsCommonRLEnv(m::MDP, s) = MDPCommonRLEnv(m, s)

Base.convert(::Type{MDP}, env::MDPCommonRLEnv) = env.m
Base.convert(::Type{POMDP}, env::POMDPCommonRLEnv) = env.m

POMDPs.convert_s(::Type{Any}, s::S, problem::Union{MDP{S},POMDP{S}}) where {S} = s
POMDPs.convert_s(::Type{S}, s, problem::Union{MDP{S},POMDP{S}}) where {S} = convert(S, s)

POMDPs.convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o
POMDPs.convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o)

0 comments on commit b58909c

Please sign in to comment.