-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added new handling of terminal states in GenerativeBelief MDP
- Loading branch information
Showing
4 changed files
with
122 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
gbmdp_handle_terminal(pomdp, updater, b, s, a, rng) = nothing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,84 @@ | ||
""" | ||
GenerativeBeliefMDP(pomdp, updater) | ||
GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater)) | ||
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. | ||
""" | ||
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, B, A} <: MDP{B, A} | ||
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A} | ||
pomdp::P | ||
updater::U | ||
terminal_behavior::T | ||
end | ||
|
||
function GenerativeBeliefMDP(pomdp::P, up::U) where {P<:POMDP, U<:Updater} | ||
# XXX hack to determine belief type | ||
b0 = initialize_belief(up, initialstate(pomdp)) | ||
GenerativeBeliefMDP{P, U, typeof(b0), actiontype(pomdp)}(pomdp, up) | ||
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater)) | ||
B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior) | ||
GenerativeBeliefMDP{typeof(pomdp), | ||
typeof(updater), | ||
typeof(terminal_behavior), | ||
B, | ||
actiontype(pomdp) | ||
}(pomdp, updater, terminal_behavior) | ||
end | ||
|
||
function initialstate(bmdp::GenerativeBeliefMDP) | ||
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) | ||
end | ||
|
||
function POMDPs.gen(bmdp::GenerativeBeliefMDP, b, a, rng::AbstractRNG) | ||
s = rand(rng, b) | ||
if isterminal(bmdp.pomdp, s) | ||
bp = gbmdp_handle_terminal(bmdp.pomdp, bmdp.updater, b, s, a, rng::AbstractRNG)::typeof(b) | ||
bp = bmdp.terminal_behavior(b, s, a, rng) | ||
return (sp=bp, r=0.0) | ||
end | ||
sp, o, r = @gen(:sp, :o, :r)(bmdp.pomdp, s, a, rng) # maybe this should have been generate_or? | ||
o, r = @gen(:o, :r)(bmdp.pomdp, s, a, rng) | ||
bp = update(bmdp.updater, b, a, o) | ||
return (sp=bp, r=r) | ||
end | ||
|
||
# Belief-dependent action space: forward the belief `b` to the underlying POMDP.
# `GenerativeBeliefMDP` has five type parameters (`P, U, T, B, A`), so all five are
# named here to make `B` bind to the belief (MDP state) type. With only four names,
# `B` would bind to the third parameter (the terminal-behavior type) and this method
# would never match a belief argument.
actions(bmdp::GenerativeBeliefMDP{P,U,T,B,A}, b::B) where {P,U,T,B,A} = actions(bmdp.pomdp, b)
# Belief-independent action space: delegate to the underlying POMDP.
actions(bmdp::GenerativeBeliefMDP) = actions(bmdp.pomdp)
|
||
isterminal(bmdp::GenerativeBeliefMDP, b) = all(isterminal(bmdp.pomdp, s) for s in support(b)) | ||
isterminal(bmdp::GenerativeBeliefMDP, b) = all(s -> isterminal(bmdp.pomdp, s) || pdf(b, s) == 0.0, support(b)) | ||
isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true | ||
|
||
discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp) | ||
|
||
# override this if you want to handle it in a special way | ||
function gbmdp_handle_terminal(pomdp::POMDP, updater::Updater, b, s, a, rng) | ||
@warn(""" | ||
Sampled a terminal state for a GenerativeBeliefMDP transition - not sure how to proceed, but will try. | ||
function determine_gbmdp_state_type(pomdp, updater) | ||
b0 = initialize_belief(updater, initialstate(pomdp)) | ||
return typeof(b0) | ||
end | ||
|
||
See $(@__FILE__) and implement a new method of POMDPToolbox.gbmdp_handle_terminal if you want special behavior in this case. | ||
determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater) | ||
|
||
""", maxlog=1) | ||
sp, o, r = @gen(:sp, :o, :r)(pomdp, s, a, rng) | ||
bp = update(updater, b, a, o) | ||
return bp | ||
struct DefaultGBMDPTerminalBehavior{M, U} | ||
pomdp::M | ||
updater::U | ||
end | ||
|
||
function initialstate(bmdp::GenerativeBeliefMDP) | ||
return Deterministic(initialize_belief(bmdp.updater, initialstate(bmdp.pomdp))) | ||
# Default handling for the case where a terminal state `s` was sampled from belief `b`.
#
# Arguments: `b` is the current belief, `s` the sampled (terminal) state, `a` the
# action, and `rng` the random number generator; all are forwarded unchanged.
#
# Returns the value produced by a user-defined `gbmdp_handle_terminal` method if it
# is not `nothing`; otherwise the result of `TerminalStateTerminalBehavior()`.
function (tb::DefaultGBMDPTerminalBehavior)(b, s, a, rng)

    # This code block is only to handle backwards compatibility for the deprecated
    # gbmdp_handle_terminal function.
    bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)

    # Use an identity check: `!=` would dispatch to `==` for whatever type the user
    # returned, which may be expensive or may not return a `Bool`.
    if bp !== nothing # user has implemented gbmdp_handle_terminal
        Base.depwarn("Using gbmdp_handle_terminal to specify terminal behavior for a GenerativeBeliefMDP is deprecated. Use the terminal_behavior keyword argument instead.", :gbmdp_handle_terminal)
        return bp
    end

    return TerminalStateTerminalBehavior()(b, s, a, rng)
end
|
||
determine_gbmdp_state_type(pomdp, updater, tb::DefaultGBMDPTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior()) | ||
|
||
struct ContinueTerminalBehavior{M, U} | ||
pomdp::M | ||
updater::U | ||
end | ||
|
||
# deprecated in POMDPs v0.9 | ||
function initialstate(bmdp::GenerativeBeliefMDP, rng::AbstractRNG) | ||
return initialize_belief(bmdp.updater, initialstate(bmdp.pomdp)) | ||
function (tb::ContinueTerminalBehavior)(b, s, a, rng) | ||
o, r = @gen(:o, :r)(tb.pomdp, s, a, rng) | ||
return update(tb.updater, b, a, o) | ||
end | ||
|
||
struct TerminalStateTerminalBehavior end | ||
(tb::TerminalStateTerminalBehavior)(args...) = terminalstate | ||
determine_gbmdp_state_type(pomdp, updater, tb::TerminalStateTerminalBehavior) = promote_type(determine_gbmdp_state_type(pomdp, updater), TerminalState) |
66 changes: 62 additions & 4 deletions
66
lib/POMDPTools/test/model_tools/test_generative_belief_mdp.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,65 @@ | ||
let | ||
pomdp = BabyPOMDP() | ||
up = updater(pomdp) | ||
@testset "GenerativeBeliefMDP" begin | ||
@testset "Baby" begin | ||
pomdp = BabyPOMDP() | ||
up = updater(pomdp) | ||
|
||
bmdp = GenerativeBeliefMDP(pomdp, up) | ||
b = initialstate(bmdp, Random.default_rng()) | ||
bmdp = GenerativeBeliefMDP(pomdp, up) | ||
b = rand(initialstate(bmdp)) | ||
@test rand(b) isa statetype(pomdp) | ||
|
||
@test simulate(RolloutSimulator(max_steps=10), bmdp, RandomPolicy(bmdp)) <= 0 | ||
end | ||
|
||
terminal_test_m = QuickPOMDP( | ||
states = 1:2, | ||
actions = 1:2, | ||
observations = 1:2, | ||
transition = (s, a) -> Deterministic(1), | ||
observation = (a, sp) -> Deterministic(sp), | ||
reward = s -> 0.0, | ||
isterminal = ==(1), | ||
initialstate = Deterministic(2) | ||
) | ||
|
||
@testset "Terminal Default" begin | ||
up = DiscreteUpdater(terminal_test_m) | ||
bm = GenerativeBeliefMDP(terminal_test_m, up) | ||
|
||
hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10)) | ||
@test length(hist) == 1 | ||
@test only(hist).s == DiscreteBelief(terminal_test_m, [0.0, 1.0]) | ||
@test only(hist).sp == DiscreteBelief(terminal_test_m, [1.0, 0.0]) | ||
@test !isterminal(bm, only(hist).s) | ||
@test isterminal(bm, only(hist).sp) | ||
end | ||
|
||
@testset "Terminal Uninformative Update" begin
    # Updater whose `update` ignores its belief, action and observation arguments and
    # always returns a uniform belief over all states of the model.
    # NOTE(review): `struct` definitions are normally only permitted at top level; if
    # this fails inside the testset, hoist the type and its two methods above the
    # outermost testset - confirm.
    struct UninformativeUpdater{M} <: Updater
        m::M
    end

    POMDPs.update(up::UninformativeUpdater, b, a, o) = Uniform(states(up.m))
    POMDPs.initialize_belief(up::UninformativeUpdater, d::Deterministic) = Uniform(rand(d))

    up = UninformativeUpdater(terminal_test_m)

    # default terminal behavior
    bm = GenerativeBeliefMDP(terminal_test_m, up)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test isterminal(bm, last(hist).sp)

    # explicit TerminalStateTerminalBehavior: the behavior object must actually be
    # passed to the constructor, otherwise this case only re-tests the default
    behavior = TerminalStateTerminalBehavior()
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp"))
    @test last(hist).sp === terminalstate
    @test isterminal(bm, last(hist).sp)

    # ContinueTerminalBehavior: the simulation should run until max_steps
    behavior = ContinueTerminalBehavior(terminal_test_m, up)
    bm = GenerativeBeliefMDP(terminal_test_m, up, terminal_behavior=behavior)
    hist = collect(stepthrough(bm, RandomPolicy(bm), "s,sp", max_steps=10))
    @test length(hist) == 10
end
|
||
end | ||
end |