Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Jun 24, 2024
1 parent 05052df commit 97383f3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/envs/circular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,13 @@ function POMDPs.observations(pomdp::CircularMaze)
# NOTE: In JuliaPOMDPs, an observation space is NOT the set of possible distributions, but rather union of the support of all possible observations
corridors = 1:pomdp.n_corridors # from CMAZE_SENSE_CORRIDOR
perms = permutations(pomdp.probabilities)
space = chain(corridors, perms) # generator
space = Iterators.flatten(corridors, perms) # generator
return space
end

# TODO: maybe implement POMDPs.obsindex

# TODO: confirm that transitions are non-Deterministic
function POMDPs.transition(pomdp::CircularMaze, s::CircularMazeState, a::Integer)
@assert a in actions(pomdp) "Unrecognized action $a"
if a == CMAZE_DECLARE_GOAL
Expand Down
12 changes: 12 additions & 0 deletions src/updaters/kalman.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import POMDPs

struct HistoryUpdater <: POMDPs.Updater end

POMDPs.initialize_belief(up::HistoryUpdater, d) = Any[d]

function POMDPs.update(up::HistoryUpdater, b, a, o)
bp = copy(b)
push!(bp, a)
push!(bp, o)
return bp
end
2 changes: 1 addition & 1 deletion test/circular_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
let
pomdp = CircularMaze()
pomdp = CircularMaze(2, 5, 0.99)

policy = RandomPolicy(pomdp, rng=MersenneTwister(2))
sim = RolloutSimulator(rng=MersenneTwister(3), max_steps=100)
Expand Down

0 comments on commit 97383f3

Please sign in to comment.