Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* added test

* tried by converting in init_node

* fixed tests

* actually added test file
  • Loading branch information
zsunberg authored Aug 3, 2023
1 parent f46cc4d commit 3a93898
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 5 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[compat]
BasicPOMCP = "0.3.8"
Colors = "0.12"
CommonRLSpaces = "0.2"
D3Trees = "0.3"
MCTS = "0.5"
POMDPs = "0.9"
POMDPTools = "0.1"
POMDPs = "0.9"
Parameters = "0.12"
ParticleFilters = "0.5"
StaticArrays = "1"
julia = "1.1"

[extras]
CommonRLSpaces = "408f5b3e-f2a2-48a6-b4bb-c8aa44c458e6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distributions", "POMDPModels", "Test"]
test = ["CommonRLSpaces", "Distributions", "LinearAlgebra", "POMDPModels", "StaticArrays", "Test"]
6 changes: 3 additions & 3 deletions src/beliefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ struct POWNodeBelief{S,A,O,P}
dist::CategoricalVector{Tuple{S,Float64}}

POWNodeBelief{S,A,O,P}(m,a,o,d) where {S,A,O,P} = new(m,a,o,d)
function POWNodeBelief{S, A, O, P}(m::P, s::S, a::A, sp::S, o::O, r) where {S, A, O, P}
cv = CategoricalVector{Tuple{S,Float64}}((sp, convert(Float64, r)),
function POWNodeBelief{S, A, O, P}(m::P, s, a, sp, o, r) where {S, A, O, P}
cv = CategoricalVector{Tuple{S,Float64}}((convert(S, sp), convert(Float64, r)),
obs_weight(m, s, a, sp, o))
new(m, a, o, cv)
end
end

function POWNodeBelief(model::POMDP{S,A,O}, s::S, a::A, sp::S, o::O, r) where {S,A,O}
function POWNodeBelief(model::POMDP{S,A,O}, s, a, sp, o, r) where {S,A,O}
POWNodeBelief{S,A,O,typeof(model)}(model, s, a, sp, o, r)
end

Expand Down
20 changes: 20 additions & 0 deletions test/discussion_513.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using CommonRLSpaces
using Distributions
using LinearAlgebra
using StaticArrays

struct D513POMDP <: POMDP{SVector{3,Float64}, SVector{2,Float64}, SVector{3,Float64}} end

POMDPs.states(m::D513POMDP) = Box([-5,-5,-3], [5,5,3])
POMDPs.actions(m::D513POMDP) = Box([-5,-5], [5,5])
POMDPs.observations(m::D513POMDP) = Box([-5,-5,-3], [5,5,3])
POMDPs.transition(m::D513POMDP, s, a, dt=0.1) = MvNormal(s, Diagonal([0.1,0.1,0.1]))
POMDPs.observation(m::D513POMDP, s, a, sp) = MvNormal(sp, Diagonal([0.001,0.001,0.001]))
POMDPs.reward(m::D513POMDP, s, a, sp) = 0
POMDPs.discount(m::D513POMDP) = 0.9

m = D513POMDP()
solver = POMCPOWSolver()
policy = solve(solver, m)
a = action(policy, Deterministic([0.0,0.0,0.0]))
@test a in actions(m)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,8 @@ using D3Trees
@testset "init_node_sr_belief_error" begin
include("init_node_sr_belief_error.jl")
end;

@testset "Discussion 513" begin
include("discussion_513.jl")
end
end;

0 comments on commit 3a93898

Please sign in to comment.