Skip to content

Commit

Permalink
added new handling of terminal states in GenerativeBelief MDP
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jul 10, 2024
1 parent 5893e86 commit eeed769
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 27 deletions.
9 changes: 8 additions & 1 deletion lib/POMDPTools/src/ModelTools/ModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ export
terminalstate
include("terminal_state.jl")

export GenerativeBeliefMDP
export GenerativeBeliefMDP,
DefaultGBMDPTerminalBehavior,
ContinueTerminalBehavior,
TerminalStateTerminalBehavior
include("generative_belief_mdp.jl")

export FullyObservablePOMDP
Expand Down Expand Up @@ -78,4 +81,8 @@ export
reward_vectors
include("matrices.jl")

export
gbmdp_handle_terminal
include("deprecated.jl")

end
1 change: 1 addition & 0 deletions lib/POMDPTools/src/ModelTools/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing
73 changes: 51 additions & 22 deletions lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,84 @@
"""
GenerativeBeliefMDP(pomdp, updater)
GenerativeBeliefMDP(pomdp, updater, terminal_behavior)
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`.
"""
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A}
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A}
pomdp::P
updater::U
terminal_behavior::T
end

function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater}
# XXX hack to determine belief type
b0 = initialize_belief(up, initialstate(pomdp))
GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up)
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater))
B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
GenerativeBeliefMDP{typeof(pomdp),
typeof(updater),
typeof(terminal_behavior),
B,
actiontype(pomdp)
}(pomdp, updater, terminal_behavior)
end

function initialstate(bmdp::GenerativeBeliefMDP)
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)))
end

function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG)
s = rand(rng, b)
if isterminal(bmdp.pomdp, s)
bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b)
bp = bmdp.terminal_behavior(b, s, a, rng)
return (sp=bp, r=0.0)
end
sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or?
o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng)
bp = update(bmdp.updater, b, a, o)
return (sp=bp, r=r)
end

actions(bmdp::GenerativeBeliefMDP{P,U,B,A}, b::B) where {P,U,B,A} = actions(bmdp.pomdp, b)
actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp)

isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b))
isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b))
isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true

discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp)

# override this if you want to handle it in a special way
function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng)
@warn("""
Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try.
function determine_gbmdp_state_type(pomdp, updater)
b0 = initialize_belief(updater, initialstate(pomdp))
return typeof(b0)
end

See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case.
determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater)

""", maxlog=1)
sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng)
bp = update(updater, b, a, o)
return bp
struct DefaultGBMDPTerminalBehavior{M, U}
pomdp::M
updater::U
end

function initialstate(bmdp::GenerativeBeliefMDP)
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)))
function (tb::DefaultGBMDPTerminalBehavior)(b, s, a, rng)

# This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function
bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)
if bp != nothing # user has implemented gbmdp_handle_terminal
Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal)
return bp
end

return TerminalStateTerminalBehavior()(b, s, a, rng)
end

determine_gbmdp_state_type(pomdp, updater, tb::DefaultGBMDPTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())

struct ContinueTerminalBehavior{M, U}
pomdp::M
updater::U
end

# deprecated in POMDPs v0.9
function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG)
return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))
function (tb::ContinueTerminalBehavior)(b, s, a, rng)
o, r = @gen(:o, :r)(tb.pomdp, s, a, rng)
return update(tb.updater, b, a, o)
end

struct TerminalStateTerminalBehavior end
(tb::TerminalStateTerminalBehavior)(args...) = terminalstate
determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState)
66 changes: 62 additions & 4 deletions lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,65 @@
let
pomdp = BabyPOMDP()
up = updater(pomdp)
@testset "GenerativeBeliefMDP" begin
@testset "Baby" begin
pomdp = BabyPOMDP()
up = updater(pomdp)

bmdp = GenerativeBeliefMDP(pomdp, up)
b = initialstate(bmdp, Random.default_rng())
bmdp = GenerativeBeliefMDP(pomdp, up)
b = rand(initialstate(bmdp))
@test rand(b) isa statetype(pomdp)

@test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0
end

terminal_test_m = QuickPOMDP(
states = 1:2,
actions = 1:2,
observations = 1:2,
transition = (s, a) -> Deterministic(1),
observation = (a, sp) -> Deterministic(sp),
reward = s -> 0.0,
isterminal = ==(1),
initialstate = Deterministic(2)
)

@testset "Terminal Default" begin
up = DiscreteUpdater(terminal_test_m)
bm = GenerativeBeliefMDP(terminal_test_m, up)

hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
@test length(hist) == 1
@test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0])
@test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0])
@test !isterminal(bm, only(hist).s)
@test isterminal(bm, only(hist).sp)
end

@testset "Terminal Uninformative Update" begin
struct UninformativeUpdater{M} <: Updater
m::M
end

POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m))
POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d))

up = UninformativeUpdater(terminal_test_m)

# default terminal behavior
bm = GenerativeBeliefMDP(terminal_test_m, up)
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")
@test isterminal(bm, last(hist).sp)

behavior = TerminalStateTerminalBehavior()
bm = GenerativeBeliefMDP(terminal_test_m, up)
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp")
@test last(hist).sp === terminalstate
@test isterminal(bm, last(hist).sp)

behavior = ContinueTerminalBehavior(terminal_test_m, up)
bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
@test length(hist) == 10
end

end
end

0 comments on commit eeed769

Please sign in to comment.