Skip to content

Commit

Permalink
major changes
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Mar 26, 2024
1 parent f733630 commit 79c0726
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 98 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LocalApproximationValueIteration = "a40420fb-f401-52da-a663-f502e5b95060"
LocalFunctionApproximation = "db97f5ab-fc25-52dd-a8f9-02a257c35074"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCTS = "e12ccd36-dcad-5f33-8774-9175229e7b33"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Expand Down
8 changes: 5 additions & 3 deletions src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Infiltrator

using POMDPs
using POMDPTools
using ParticleFilters
using LocalApproximationValueIteration
using LocalFunctionApproximation

Expand All @@ -26,8 +27,9 @@ export
MultivariateStatsCompressor,
PCACompressor,
KernelPCACompressor,
PPCACompressor
# FactorAnalysisCompressor # TODO: debug
PPCACompressor,
FactorAnalysisCompressor,
MDSCompressor
include("compressors/compressor.jl")
include("compressors/mv_stats.jl")

Expand All @@ -36,7 +38,7 @@ export
include("sampler.jl")

export
CompressedBeliefMDP,
CompressedBeliefMDP
include("cbmdp.jl")

export
Expand Down
24 changes: 15 additions & 9 deletions src/cbmdp.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
struct CompressedBeliefMDP{B, A} <: MDP{B, A}
bmdp::GenerativeBeliefMDP
compressor::Compressor
ϕ::Bijection # ϕ: belief ↦ compress(compressor, belief); NOTE: While compressions aren't usually injective, we cache compressed beliefs on a first-come, first-served basis, so the *cache* is effectively bijective.
ϕ::Bijection # ϕ: belief ↦ compressor(belief); NOTE: While compressions aren't usually injective, we cache compressed beliefs on a first-come, first-served basis, so the *cache* is effectively bijective.
end


Expand All @@ -10,7 +10,7 @@ function CompressedBeliefMDP(pomdp::POMDP, updater::Updater, compressor::Compres
# Hack to determine typeof(b̃)
bmdp = GenerativeBeliefMDP(pomdp, updater)
b = initialstate(bmdp).val
= compress(compressor, convert_s(AbstractVector{Float64}, b, bmdp.pomdp))
= compressor(convert_s(AbstractVector{Float64}, b, bmdp.pomdp))
B = typeof(b)
= typeof(b̃)
ϕ = Bijection{B, B̃}()
Expand All @@ -23,10 +23,14 @@ function decode(m::CompressedBeliefMDP, b̃)
end

function encode(m::CompressedBeliefMDP, b)
b = convert_s(AbstractVector{Float64}, b, m)
= get!(m.ϕ, b) do
b = convert_s(AbstractArray{Float64}, b, m) # TODO: not sure if I need a `let b = ...` here
compress(m.compressor, b) # NOTE: compress is only called if b ∉ domain(m.ϕ)
if b domain(m.ϕ)
= m.ϕ[b]
else
b_numerical = convert_s(AbstractArray{Float64}, b, m)
= m.compressor(b_numerical)
if image(m.ϕ)
m.ϕ[b] =
end
end
return
end
Expand All @@ -38,14 +42,17 @@ function POMDPs.gen(m::CompressedBeliefMDP, b̃, a, rng::Random.AbstractRNG)
return (sp=b̃p, r=r)
end

# TODO: handle sampling terminal states /Users/logan/.julia/packages/POMDPTools/7Rekv/src/ModelTools/generative_belief_mdp.jl

# TODO: use macro forwarding
# TODO: read about orthogonalized code in the Julia documentation
POMDPs.states(m::CompressedBeliefMDP) = [encode(m, initialize_belief(m.bmdp.updater, s)) for s in states(m.bmdp.pomdp)]
POMDPs.initialstate(m::CompressedBeliefMDP) = encode(m, initialstate(m.bmdp))
POMDPs.actions(m::CompressedBeliefMDP, b̃) = actions(m.bmdp, decode(m, b̃))
POMDPs.actions(m::CompressedBeliefMDP) = actions(m.bmdp)
POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a)
POMDPs.isterminal(m::CompressedBeliefMDP, b̃) = isterminal(m.bmdp, decode(m, b̃))
POMDPs.discount(m::CompressedBeliefMDP) = discount(m.bmdp)
POMDPs.initialstate(m::CompressedBeliefMDP) = encode(m, initialstate(m.bmdp))
POMDPs.actionindex(m::CompressedBeliefMDP, a) = actionindex(m.bmdp.pomdp, a)

POMDPs.convert_s(t::Type, s, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp)
POMDPs.convert_s(t::Type{<:AbstractArray}, s::AbstractArray, m::CompressedBeliefMDP) = convert_s(t, s, m.bmdp.pomdp) # NOTE: this second method exists to get around a requirement from POMDPLinter
Expand All @@ -54,7 +61,6 @@ POMDPs.convert_s(t::Type{<:AbstractArray}, s::AbstractArray, m::CompressedBelief
ExplicitDistribution = Union{SparseCat, BoolDistribution, Deterministic, Uniform} # distributions w/ explicit PDFs from POMDPs.jl (https://juliapomdp.github.io/POMDPs.jl/latest/POMDPTools/distributions/#Implemented-Distributions)
POMDPs.convert_s(::Type{<:AbstractArray}, s::ExplicitDistribution, m::POMDP) = [pdf(s, x) for x in states(m)]


# function POMDPs.convert_s(t::Type{V}, s, m::CompressedBeliefMDP) where V<:AbstractArray
# convert_s(t, s, m.bmdp.pomdp)
# end
Expand Down
24 changes: 1 addition & 23 deletions src/compressors/compressor.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,8 @@
"""
Base type for MDP/POMDP belief compressors.
"""
abstract type Compressor end


"""
fit!(compressor::Compressor, beliefs)
Fit the compressor to beliefs.
"""
function fit! end


"""
compress(compressor::Compressor, beliefs)
Compress the sampled beliefs using method associated with compressor, and returns a compressed representation.
"""
function compress end


"""
decompress(compressor::Compressor, compressed)
Decompress the compressed beliefs using method associated with compressor, and returns the reconstructed beliefs.
"""
function decompress end

# TODO: remove decompress and make compress a functor (https://docs.julialang.org/en/v1/manual/methods/#Note-on-Optional-and-keyword-Arguments)
function fit! end
5 changes: 5 additions & 0 deletions src/compressors/manifold_learning.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Wrapper for ManifoldLearning.jl. See https://wildart.github.io/ManifoldLearning.jl/stable/.
"""

# TODO
38 changes: 17 additions & 21 deletions src/compressors/mv_stats.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
"""
Wrappers for MultivariateStats.jl. See https://juliastats.org/MultivariateStats.jl/stable/.
"""

using MultivariateStats

mutable struct MultivariateStatsCompressor{T<:MultivariateStats.AbstractDimensionalityReduction} <: Compressor

mutable struct MVSCompressor{T<:MultivariateStats.AbstractDimensionalityReduction} <: Compressor
const maxoutdim::Integer
M # TODO: check if this is Julian (how to replace unde)
M
end

function fit!(compressor::MultivariateStatsCompressor{T}, beliefs) where T<:MultivariateStats.AbstractDimensionalityReduction
compressor.M = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim)
end
# Compress `beliefs` with the model stored in `c.M`.
#
# A two-dimensional input is adjointed (`'`) before being handed to
# `MultivariateStats.predict`, and the prediction is adjointed back. Any other input is
# passed to `predict` unchanged and the prediction is flattened with `vec`.
function (c::MVSCompressor)(beliefs)
    if ndims(beliefs) == 2
        return MultivariateStats.predict(c.M, beliefs')'
    end
    return vec(MultivariateStats.predict(c.M, beliefs))
end

# TODO: is there a way to solve this w/ multiple dispatch? clean up
function compress(compressor::MultivariateStatsCompressor, beliefs)
# TODO: is there better way to do this?
return ndims(beliefs) == 2 ? predict(compressor.M, beliefs')' : vec(predict(compressor.M, beliefs))
"""
    fit!(compressor::MVSCompressor{T}, beliefs)

Fit a MultivariateStats model of type `T` to `beliefs` and store it in `compressor.M`.

`beliefs` is adjointed (`'`) before being passed to `MultivariateStats.fit`, and the output
dimension is capped by `compressor.maxoutdim`. The fitted model is returned.
"""
function fit!(compressor::MVSCompressor{T}, beliefs) where T<:MultivariateStats.AbstractDimensionalityReduction
    model = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim)
    compressor.M = model
    return model
end

decompress(compressor::MultivariateStatsCompressor, compressed) = MultivariateStats.reconstruct(compressor.M, compressed)

MultivariateStatsCompressor(maxoutdim::Integer, T) = MultivariateStatsCompressor{T}(maxoutdim, nothing)
MVSCompressor(maxoutdim::Integer, T) = MVSCompressor{T}(maxoutdim, nothing)

# PCA Compressors
PCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, PCA)
KernelPCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, KernelPCA)
PPCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, PPCA)

# TODO: debug this
function fit!(compressor::MultivariateStatsCompressor{KernelPCA}, beliefs)
compressor.M = MultivariateStats.fit(KernelPCA, beliefs'; maxoutdim=compressor.maxoutdim, inverse=true)
end
# Convenience constructors: each builds an `MVSCompressor` parametrized by the named
# MultivariateStats model type, with `maxoutdim` as the cap on the output dimension.
# Principal component analysis.
PCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PCA)
# Kernel principal component analysis.
KernelPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, KernelPCA)
# Probabilistic principal component analysis.
PPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PPCA)

# Factor Analysis Compressor
FactorAnalysisCompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, FactorAnalysis)
FactorAnalysisCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, FactorAnalysis)

# Multidimensional Scaling
MDSCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, MDS)
7 changes: 6 additions & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,9 @@ end

function sample(pomdp::POMDP, policy::ExplorationPolicy, updater::Updater, n::Integer)
# TODO:
end
end


"""
Adapted from algorithm 21.13 from AFDM
"""
98 changes: 69 additions & 29 deletions src/solver.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,81 @@
struct CompressedBeliefSolver <: Solver
explorer::Union{Policy, ExplorationPolicy}
updater::Updater
compressor::Compressor
base_solver::Solver
n::Integer
### POLICY ###

"""
    CompressedBeliefPolicy(m, base_policy)

Policy that answers queries by encoding a belief with `m` and delegating to `base_policy`.

The struct is parametrized on the concrete types of both fields so that field access is
type-stable; the positional constructor is unchanged.

# Fields
- `m`: the `CompressedBeliefMDP` that supplies the belief updater and the belief encoding.
- `base_policy`: the policy over compressed beliefs that is queried after encoding.
"""
struct CompressedBeliefPolicy{M<:CompressedBeliefMDP, P<:Policy} <: POMDPs.Policy
    m::M
    base_policy::P
end

function CompressedBeliefSolver(
explorer::Union{Policy, ExplorationPolicy},
updater::Updater,
compressor::Compressor,
base_solver::Solver;
n=100
)
return CompressedBeliefSolver(explorer, updater, compressor, base_solver, n)
"""
    POMDPs.action(p::CompressedBeliefPolicy, s)

Return the action chosen by `p.base_policy` for the compressed belief derived from `s`.

`s` is first turned into a belief with the policy's updater via `initialize_belief`
(presumably `s` is a state distribution; confirm against callers), then encoded with
`encode` before the base policy is queried.
"""
function POMDPs.action(p::CompressedBeliefPolicy, s)
    belief = initialize_belief(p.m.bmdp.updater, s)
    compressed = encode(p.m, belief)
    return action(p.base_policy, compressed)
end

# TODO: make compressed solver that infers everything
# TODO: make compressed solver that uses local FA solver
"""
    POMDPs.value(p::CompressedBeliefPolicy, s)

Return the value that `p.base_policy` assigns to the compressed belief derived from `s`.

`s` is first turned into a belief with the policy's updater via `initialize_belief`
(presumably `s` is a state distribution; confirm against callers), then encoded with
`encode` before the base policy is queried.
"""
function POMDPs.value(p::CompressedBeliefPolicy, s)
    belief = initialize_belief(p.m.bmdp.updater, s)
    compressed = encode(p.m, belief)
    return value(p.base_policy, compressed)
end

# Return the belief updater of the belief MDP wrapped by the policy's compressed MDP.
function POMDPs.updater(p::CompressedBeliefPolicy)
    return p.m.bmdp.updater
end

struct CompressedBeliefPolicy <: Policy
### SOLVER ###

struct CompressedBeliefSolver <: Solver
m::CompressedBeliefMDP
base_policy::Policy
base_solver::Solver
end

POMDPs.action(p::CompressedBeliefPolicy, b) = action(p.base_policy, encode(m, b))
POMDPs.value(p::CompressedBeliefPolicy, b) = value(p.base_policy, encode(m, b))
POMDPs.updater(p::CompressedBeliefPolicy) = p.m.bmdp.updater
# TODO: add seeding
function CompressedBeliefSolver(
pomdp::POMDP;
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp),
updater::Updater=applicable(POMDPs.states, pomdp) ? DiscreteUpdater(pomdp) : BootstrapFilter(pomdp, 5000), # hack to determine default updater, may select incompatible Updater
compressor::Compressor=PCACompressor(1),
n::Integer=50, # max number of belief samples to compress
interp::Union{Nothing, LocalFunctionApproximator}=nothing,
k=1, # k nearest neighbors; only used if interp is nothing
verbose=false,
max_iterations=1000, # for value iteration
n_generative_samples=10, # number of steps to look ahead when calculating the expected reward
belres::Float64=1e-3,
)
# sample beliefs
B = sample(pomdp, explorer, updater, n)

function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP)
B = sample(pomdp, solver.explorer, solver.updater, solver.n)
# compress beliefs and cache mapping
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)'
fit!(solver.compressor, B_numerical)
= compress(solver.compressor, B_numerical)
m = CompressedBeliefMDP(pomdp, solver.updater, solver.compressor)
fit!(compressor, B_numerical)
= compressor(B_numerical)
ϕ = Dict(unique(t->t[2], zip(B, eachrow(B̃))))
merge!(m.ϕ, ϕ) # update compression cache
base_policy = solve(solver.base_solver, m)
return CompressedBeliefPolicy(m, base_policy)

# construct the compressed belief-state MDP
m = CompressedBeliefMDP(pomdp, updater, compressor)
merge!(m.ϕ, ϕ) # update the compression cache

# define the interpolator for the solver
if isnothing(interp)
data = map(row->SVector(row...), eachrow(B̃))
tree = KDTree(data)
interp = LocalNNFunctionApproximator(tree, data, k) # TODO: check that we need this
end

# build the base solver
base_solver = LocalApproximationValueIterationSolver(
interp,
max_iterations=max_iterations,
belres=belres,
verbose=verbose,
is_mdp_generative=true,
n_generative_samples=n_generative_samples
)

return CompressedBeliefSolver(m, base_solver)
end

"""
    POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP)

Solve the compressed belief MDP held by `solver` and wrap the result in a
`CompressedBeliefPolicy`.

The MDP that is solved is `solver.m`, not one derived from the `pomdp` argument, so a
warning is emitted when `pomdp` is not the identical object stored in `solver.m`.
"""
function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP)
    cbmdp = solver.m
    if !(cbmdp.bmdp.pomdp === pomdp)
        @warn "Got $pomdp, but solver.m.bmdp.pomdp $(solver.m.bmdp.pomdp) isn't identical"
    end
    return CompressedBeliefPolicy(cbmdp, solve(solver.base_solver, cbmdp))
end
20 changes: 10 additions & 10 deletions test/mv_stats_tests.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
function test_compressor(C::Function, maxoutdim::Int)
pomdp = TMaze(20, 0.99)
sampler = DiscreteRandomSampler(pomdp)
pomdp = BabyPOMDP() # TODO: change to TMaze once I figure out how to properly sample
compressor = C(maxoutdim)
solver = CompressedSolver(pomdp, sampler, compressor; n_samples=20, verbose=false, max_iterations=5)
approx_policy = solve(solver, pomdp)
solver = CompressedBeliefSolver(pomdp; compressor=compressor, n=20)
policy = solve(solver, pomdp)
s = initialstate(pomdp)
_ = value(approx_policy, s)
_ = action(approx_policy, s)
return approx_policy
_ = action(policy, s)
_ = value(policy, s)
return policy
end

MV_STATS_COMPRESSORS = (
PCACompressor,
KernelPCACompressor,
PPCACompressor,
# FactorAnalysisCompressor
FactorAnalysisCompressor,
MDSCompressor
)

@testset "Compressor Tests" begin
@testset "$C" for C in MV_STATS_COMPRESSORS
@inferred test_compressor(C, 1)
@inferred test_compressor(C, 10)
@test_nowarn test_compressor(C, 1)
@test_nowarn test_compressor(C, 2)
end
end
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using CompressedBeliefMDPs
using Test

using POMDPs
using POMDPModels
using POMDPs, POMDPModels, POMDPTools
# TODO: also test w/ FA solver
using MCTS

@testset "CompressedBeliefMDPs.jl" begin
include("mv_stats_tests.jl")
include("solver_tests.jl")
end
9 changes: 9 additions & 0 deletions test/solver_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Smoke-test the solver on several POMDPs: building it and running `test_solver` on the
# result must not emit warnings.
@testset "Solver Tests" begin
    @testset "$pomdp" for pomdp in (BabyPOMDP(), TigerPOMDP(), TMaze(6, 0.99), LightDark1D())
        # Build a fresh compressor per problem and pass it explicitly. The solver
        # constructor fits (mutates) the compressor, so sharing one instance across
        # problems would couple the test cases.
        compressor = PCACompressor(1)
        solver = CompressedBeliefSolver(pomdp; compressor=compressor, n=10)
        @test_nowarn test_solver(solver, pomdp)
    end
end

# TODO: add test w/ MCTS

0 comments on commit 79c0726

Please sign in to comment.