From eeed769a667f98133b6409dd09b53c23a6beb2a5 Mon Sep 17 00:00:00 2001
From: Zachary Sunberg
Date: Wed, 10 Jul 2024 17:18:48 -0600
Subject: [PATCH] added new handling of terminal states in GenerativeBelief MDP

---
Review notes (free-form text in this position is not part of the commit message):
- Line breaks were restored from the hunk headers and diffstat; indentation inside
  lines follows the usual 4-space Julia convention and should be checked against
  the working tree.
- The docstring now shows `terminal_behavior` as a keyword argument, matching the
  only constructor this patch defines.
- The backwards-compatibility check uses `!== nothing` (identity) instead of
  `!= nothing`, so it does not depend on how a belief type overloads `==`.
- Two `collect(stepthrough(...))` calls in the tests were missing a closing paren.
- The explicit `TerminalStateTerminalBehavior()` test now actually passes the
  behavior object to `GenerativeBeliefMDP`.
- `UninformativeUpdater` and its methods are defined above the `let` block because
  struct definitions are only allowed at top level.
- Only added lines were touched, so hunk sizes and the diffstat are unchanged; the
  blob ids on the `index` lines still refer to the originally generated content.

 lib/POMDPTools/src/ModelTools/ModelTools.jl   |  9 ++-
 lib/POMDPTools/src/ModelTools/deprecated.jl   |  1 +
 .../src/ModelTools/generative_belief_mdp.jl   | 73 +++++++++++++------
 .../model_tools/test_generative_belief_mdp.jl | 66 ++++++++++++++++-
 4 files changed, 122 insertions(+), 27 deletions(-)
 create mode 100644 lib/POMDPTools/src/ModelTools/deprecated.jl

diff --git a/lib/POMDPTools/src/ModelTools/ModelTools.jl b/lib/POMDPTools/src/ModelTools/ModelTools.jl
index f8d6340d..2524f352 100644
--- a/lib/POMDPTools/src/ModelTools/ModelTools.jl
+++ b/lib/POMDPTools/src/ModelTools/ModelTools.jl
@@ -40,7 +40,10 @@ export
     terminalstate
 include("terminal_state.jl")
 
-export GenerativeBeliefMDP
+export GenerativeBeliefMDP,
+       DefaultGBMDPTerminalBehavior,
+       ContinueTerminalBehavior,
+       TerminalStateTerminalBehavior
 include("generative_belief_mdp.jl")
 
 export FullyObservablePOMDP
@@ -78,4 +81,8 @@ export
     reward_vectors
 include("matrices.jl")
 
+export
+    gbmdp_handle_terminal
+include("deprecated.jl")
+
 end
diff --git a/lib/POMDPTools/src/ModelTools/deprecated.jl b/lib/POMDPTools/src/ModelTools/deprecated.jl
new file mode 100644
index 00000000..f2d08480
--- /dev/null
+++ b/lib/POMDPTools/src/ModelTools/deprecated.jl
@@ -0,0 +1 @@
+gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing
diff --git a/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl b/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
index 0b11f517..4e37ae4c 100644
--- a/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
+++ b/lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
@@ -1,26 +1,36 @@
 """
     GenerativeBeliefMDP(pomdp, updater)
+    GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater))
 
 Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`.
 """
-struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A}
+struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A}
     pomdp::P
     updater::U
+    terminal_behavior::T
 end
 
-function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater}
-    # XXX hack to determine belief type
-    b0 = initialize_belief(up, initialstate(pomdp))
-    GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up)
+function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater))
+    B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
+    GenerativeBeliefMDP{typeof(pomdp),
+                        typeof(updater),
+                        typeof(terminal_behavior),
+                        B,
+                        actiontype(pomdp)
+                       }(pomdp, updater, terminal_behavior)
+end
+
+function initialstate(bmdp::GenerativeBeliefMDP)
+    return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)))
 end
 
 function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG)
     s = rand(rng, b)
     if isterminal(bmdp.pomdp, s)
-        bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b)
+        bp = bmdp.terminal_behavior(b, s, a, rng)
         return (sp=bp, r=0.0)
     end
-    sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or?
+    o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng)
     bp = update(bmdp.updater, b, a, o)
     return (sp=bp, r=r)
 end
@@ -28,28 +38,47 @@ end
 actions(bmdp::GenerativeBeliefMDP{P,U,B,A}, b::B) where {P,U,B,A} = actions(bmdp.pomdp, b)
 actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp)
 
-isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b))
+isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b))
+isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true
 
 discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp)
 
-# override this if you want to handle it in a special way
-function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng)
-    @warn("""
-         Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try.
+function determine_gbmdp_state_type(pomdp, updater)
+    b0 = initialize_belief(updater, initialstate(pomdp))
+    return typeof(b0)
+end
 
-         See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case.
+determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater)
 
-         """, maxlog=1)
-    sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng)
-    bp = update(updater, b, a, o)
-    return bp
+struct DefaultGBMDPTerminalBehavior{M, U}
+    pomdp::M
+    updater::U
 end
 
-function initialstate(bmdp::GenerativeBeliefMDP)
-    return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)))
+function (tb::DefaultGBMDPTerminalBehavior)(b, s, a, rng)
+
+    # This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function
+    bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)
+    if bp !== nothing # user has implemented gbmdp_handle_terminal
+        Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal)
+        return bp
+    end
+
+    return TerminalStateTerminalBehavior()(b, s, a, rng)
+end
+
+determine_gbmdp_state_type(pomdp, updater, tb::DefaultGBMDPTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())
+
+struct ContinueTerminalBehavior{M, U}
+    pomdp::M
+    updater::U
 end
 
-# deprecated in POMDPs v0.9
-function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG)
-    return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))
+function (tb::ContinueTerminalBehavior)(b, s, a, rng)
+    o, r = @gen(:o, :r)(tb.pomdp, s, a, rng)
+    return update(tb.updater, b, a, o)
 end
+
+struct TerminalStateTerminalBehavior end
+(tb::TerminalStateTerminalBehavior)(args...) = terminalstate
+determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState)
diff --git a/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl b/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
index d02ce0c0..2c3ac647 100644
--- a/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
+++ b/lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
@@ -1,7 +1,65 @@
+struct UninformativeUpdater{M} <: Updater
+    m::M
+end
+
+POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m))
+POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d))
+
 let
-    pomdp = BabyPOMDP()
-    up = updater(pomdp)
+    @testset "GenerativeBeliefMDP" begin
+        @testset "Baby" begin
+            pomdp = BabyPOMDP()
+            up = updater(pomdp)
 
-    bmdp = GenerativeBeliefMDP(pomdp, up)
-    b = initialstate(bmdp, Random.default_rng())
+            bmdp = GenerativeBeliefMDP(pomdp, up)
+            b = rand(initialstate(bmdp))
+            @test rand(b) isa statetype(pomdp)
+
+            @test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0
+        end
+
+        terminal_test_m = QuickPOMDP(
+            states = 1:2,
+            actions = 1:2,
+            observations = 1:2,
+            transition = (s, a) -> Deterministic(1),
+            observation = (a, sp) -> Deterministic(sp),
+            reward = s -> 0.0,
+            isterminal = ==(1),
+            initialstate = Deterministic(2)
+        )
+
+        @testset "Terminal Default" begin
+            up = DiscreteUpdater(terminal_test_m)
+            bm = GenerativeBeliefMDP(terminal_test_m, up)
+
+            hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
+            @test length(hist) == 1
+            @test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0])
+            @test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0])
+            @test !isterminal(bm, only(hist).s)
+            @test isterminal(bm, only(hist).sp)
+        end
+
+        @testset "Terminal Uninformative Update" begin
+            up = UninformativeUpdater(terminal_test_m)
+
+            # default terminal behavior
+            bm = GenerativeBeliefMDP(terminal_test_m, up)
+            hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
+            @test isterminal(bm, last(hist).sp)
+
+            behavior = TerminalStateTerminalBehavior()
+            bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
+            hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
+            @test last(hist).sp === terminalstate
+            @test isterminal(bm, last(hist).sp)
+
+            behavior = ContinueTerminalBehavior(terminal_test_m, up)
+            bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
+            hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
+            @test length(hist) == 10
+        end
+
+    end
 end