-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f733630
commit 79c0726
Showing
11 changed files
with
143 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,8 @@ | ||
""" | ||
Base type for MDP/POMDP belief compressors. | ||
""" | ||
abstract type Compressor end | ||
|
||
|
||
""" | ||
fit!(compressor::Compressor, beliefs) | ||
Fit the compressor to beliefs. | ||
""" | ||
function fit! end | ||
|
||
|
||
""" | ||
compress(compressor::Compressor, beliefs) | ||
Compress the sampled beliefs using the method associated with the compressor, and return the compressed representation. | ||
""" | ||
function compress end | ||
|
||
|
||
""" | ||
decompress(compressor::Compressor, compressed) | ||
Decompress the compressed beliefs using the method associated with the compressor, and return the reconstructed beliefs. | ||
""" | ||
function decompress end | ||
|
||
# TODO: remove decompress and make compress a functor (https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects) | ||
function fit! end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
""" | ||
Wrapper for ManifoldLearning.jl. See https://wildart.github.io/ManifoldLearning.jl/stable/. | ||
""" | ||
|
||
# TODO |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,30 @@ | ||
""" | ||
Wrappers for MultivariateStats.jl. See https://juliastats.org/MultivariateStats.jl/stable/. | ||
""" | ||
|
||
using MultivariateStats | ||
|
||
mutable struct MultivariateStatsCompressor{T<:MultivariateStats.AbstractDimensionalityReduction} <: Compressor | ||
|
||
mutable struct MVSCompressor{T<:MultivariateStats.AbstractDimensionalityReduction} <: Compressor | ||
const maxoutdim::Integer | ||
M # TODO: check if this is Julian (how to replace unde) | ||
M | ||
end | ||
|
||
function fit!(compressor::MultivariateStatsCompressor{T}, beliefs) where T<:MultivariateStats.AbstractDimensionalityReduction | ||
compressor.M = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim) | ||
end | ||
(c::MVSCompressor)(beliefs) = ndims(beliefs) == 2 ? MultivariateStats.predict(c.M, beliefs')' : vec(MultivariateStats.predict(c.M, beliefs)) | ||
|
||
# TODO: is there a way to solve this w/ multiple dispatch? clean up | ||
function compress(compressor::MultivariateStatsCompressor, beliefs) | ||
# TODO: is there better way to do this? | ||
return ndims(beliefs) == 2 ? predict(compressor.M, beliefs')' : vec(predict(compressor.M, beliefs)) | ||
function fit!(compressor::MVSCompressor{T}, beliefs) where T<:MultivariateStats.AbstractDimensionalityReduction | ||
compressor.M = MultivariateStats.fit(T, beliefs'; maxoutdim=compressor.maxoutdim) | ||
end | ||
|
||
decompress(compressor::MultivariateStatsCompressor, compressed) = MultivariateStats.reconstruct(compressor.M, compressed) | ||
|
||
MultivariateStatsCompressor(maxoutdim::Integer, T) = MultivariateStatsCompressor{T}(maxoutdim, nothing) | ||
MVSCompressor(maxoutdim::Integer, T) = MVSCompressor{T}(maxoutdim, nothing) | ||
|
||
# PCA Compressors | ||
PCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, PCA) | ||
KernelPCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, KernelPCA) | ||
PPCACompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, PPCA) | ||
|
||
# TODO: debug this | ||
function fit!(compressor::MultivariateStatsCompressor{KernelPCA}, beliefs) | ||
compressor.M = MultivariateStats.fit(KernelPCA, beliefs'; maxoutdim=compressor.maxoutdim, inverse=true) | ||
end | ||
PCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PCA) | ||
KernelPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, KernelPCA) | ||
PPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PPCA) | ||
|
||
# Factor Analysis Compressor | ||
FactorAnalysisCompressor(maxoutdim::Integer) = MultivariateStatsCompressor(maxoutdim, FactorAnalysis) | ||
FactorAnalysisCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, FactorAnalysis) | ||
|
||
# Multidimensional Scaling | ||
MDSCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, MDS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,81 @@ | ||
struct CompressedBeliefSolver <: Solver | ||
explorer::Union{Policy, ExplorationPolicy} | ||
updater::Updater | ||
compressor::Compressor | ||
base_solver::Solver | ||
n::Integer | ||
### POLICY ### | ||
|
||
struct CompressedBeliefPolicy <: POMDPs.Policy | ||
m::CompressedBeliefMDP | ||
base_policy::Policy | ||
end | ||
|
||
function CompressedBeliefSolver( | ||
explorer::Union{Policy, ExplorationPolicy}, | ||
updater::Updater, | ||
compressor::Compressor, | ||
base_solver::Solver; | ||
n=100 | ||
) | ||
return CompressedBeliefSolver(explorer, updater, compressor, base_solver, n) | ||
function POMDPs.action(p::CompressedBeliefPolicy, s) | ||
b = initialize_belief(p.m.bmdp.updater, s) | ||
action(p.base_policy, encode(p.m, b)) | ||
end | ||
|
||
# TODO: make compressed solver that infers everything | ||
# TODO: make compressed solver that uses local FA solver | ||
function POMDPs.value(p::CompressedBeliefPolicy, s) | ||
b = initialize_belief(p.m.bmdp.updater, s) | ||
value(p.base_policy, encode(p.m, b)) | ||
end | ||
|
||
POMDPs.updater(p::CompressedBeliefPolicy) = p.m.bmdp.updater | ||
|
||
struct CompressedBeliefPolicy <: Policy | ||
### SOLVER ### | ||
|
||
struct CompressedBeliefSolver <: Solver | ||
m::CompressedBeliefMDP | ||
base_policy::Policy | ||
base_solver::Solver | ||
end | ||
|
||
POMDPs.action(p::CompressedBeliefPolicy, b) = action(p.base_policy, encode(m, b)) | ||
POMDPs.value(p::CompressedBeliefPolicy, b) = value(p.base_policy, encode(m, b)) | ||
POMDPs.updater(p::CompressedBeliefPolicy) = p.m.bmdp.updater | ||
# TODO: add seeding | ||
function CompressedBeliefSolver( | ||
pomdp::POMDP; | ||
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp), | ||
updater::Updater=applicable(POMDPs.states, pomdp) ? DiscreteUpdater(pomdp) : BootstrapFilter(pomdp, 5000), # hack to determine default updater, may select incompatible Updater | ||
compressor::Compressor=PCACompressor(1), | ||
n::Integer=50, # max number of belief samples to compress | ||
interp::Union{Nothing, LocalFunctionApproximator}=nothing, | ||
k=1, # k nearest neighbors; only used if interp is nothing | ||
verbose=false, | ||
max_iterations=1000, # for value iteration | ||
n_generative_samples=10, # number of steps to look ahead when calculating the expected reward | ||
belres::Float64=1e-3, | ||
) | ||
# sample beliefs | ||
B = sample(pomdp, explorer, updater, n) | ||
|
||
function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP) | ||
B = sample(pomdp, solver.explorer, solver.updater, solver.n) | ||
# compress beliefs and cache mapping | ||
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)' | ||
fit!(solver.compressor, B_numerical) | ||
B̃ = compress(solver.compressor, B_numerical) | ||
m = CompressedBeliefMDP(pomdp, solver.updater, solver.compressor) | ||
fit!(compressor, B_numerical) | ||
B̃ = compressor(B_numerical) | ||
ϕ = Dict(unique(t->t[2], zip(B, eachrow(B̃)))) | ||
merge!(m.ϕ, ϕ) # update compression cache | ||
base_policy = solve(solver.base_solver, m) | ||
return CompressedBeliefPolicy(m, base_policy) | ||
|
||
# construct the compressed belief-state MDP | ||
m = CompressedBeliefMDP(pomdp, updater, compressor) | ||
merge!(m.ϕ, ϕ) # update the compression cache | ||
|
||
# define the interpolator for the solver | ||
if isnothing(interp) | ||
data = map(row->SVector(row...), eachrow(B̃)) | ||
tree = KDTree(data) | ||
interp = LocalNNFunctionApproximator(tree, data, k) # TODO: check that we need this | ||
end | ||
|
||
# build the base solver | ||
base_solver = LocalApproximationValueIterationSolver( | ||
interp, | ||
max_iterations=max_iterations, | ||
belres=belres, | ||
verbose=verbose, | ||
is_mdp_generative=true, | ||
n_generative_samples=n_generative_samples | ||
) | ||
|
||
return CompressedBeliefSolver(m, base_solver) | ||
end | ||
|
||
function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP) | ||
if solver.m.bmdp.pomdp !== pomdp | ||
@warn "Got $pomdp, but solver.m.bmdp.pomdp $(solver.m.bmdp.pomdp) isn't identical" | ||
end | ||
|
||
base_policy = solve(solver.base_solver, solver.m) | ||
return CompressedBeliefPolicy(solver.m, base_policy) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,25 @@ | ||
function test_compressor(C::Function, maxoutdim::Int) | ||
pomdp = TMaze(20, 0.99) | ||
sampler = DiscreteRandomSampler(pomdp) | ||
pomdp = BabyPOMDP() # TODO: change to TMaze once I figure out how to properly sample | ||
compressor = C(maxoutdim) | ||
solver = CompressedSolver(pomdp, sampler, compressor; n_samples=20, verbose=false, max_iterations=5) | ||
approx_policy = solve(solver, pomdp) | ||
solver = CompressedBeliefSolver(pomdp; compressor=compressor, n=20) | ||
policy = solve(solver, pomdp) | ||
s = initialstate(pomdp) | ||
_ = value(approx_policy, s) | ||
_ = action(approx_policy, s) | ||
return approx_policy | ||
_ = action(policy, s) | ||
_ = value(policy, s) | ||
return policy | ||
end | ||
|
||
MV_STATS_COMPRESSORS = ( | ||
PCACompressor, | ||
KernelPCACompressor, | ||
PPCACompressor, | ||
# FactorAnalysisCompressor | ||
FactorAnalysisCompressor, | ||
MDSCompressor | ||
) | ||
|
||
@testset "Compressor Tests" begin | ||
@testset "$C" for C in MV_STATS_COMPRESSORS | ||
@inferred test_compressor(C, 1) | ||
@inferred test_compressor(C, 10) | ||
@test_nowarn test_compressor(C, 1) | ||
@test_nowarn test_compressor(C, 2) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
using CompressedBeliefMDPs | ||
using Test | ||
|
||
using POMDPs | ||
using POMDPModels | ||
using POMDPs, POMDPModels, POMDPTools | ||
# TODO: also test w/ FA solver | ||
using MCTS | ||
|
||
@testset "CompressedBeliefMDPs.jl" begin | ||
include("mv_stats_tests.jl") | ||
include("solver_tests.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
@testset "Solver Tests" begin | ||
compressor = PCACompressor(1) | ||
@testset "$pomdp" for pomdp in (BabyPOMDP(), TigerPOMDP(), TMaze(6, 0.99), LightDark1D()) | ||
solver = CompressedBeliefSolver(pomdp; n=10) | ||
@test_nowarn test_solver(solver, pomdp) | ||
end | ||
end | ||
|
||
# TODO: add test w/ MCTS |