Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prune Changes #19

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 106 additions & 45 deletions src/prune.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ end

function prune!(solver::SARSOPSolver, tree::SARSOPTree)
prune!(tree)
prune_strictly_dominated!(tree::SARSOPTree)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pruning strictly dominated alpha vecs at every iteration seems to take up a lot of time. Have you found performance to be worse when lumping it together with the conditional block containing prune_alpha!?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do remember considering this, but I don't remember the details when I compared the two. I can run some comparisons and report back. (timeline TBD)

Copy link
Member Author

@dylan-asmar dylan-asmar Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TL;DR; I'd recommend keeping it as pruning at every iteration.

Just ran the comparison for delta=1e-4. The allocations are a bit higher when running at every iteration, but the overall process is faster. The difference between these benchmarks and the original ones posted at the submission of the PR is due to the suggested changes (which help out quite a bit!).

BabyPOMDP

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 5.0 (for policy run)

Benchmark

At Every Iteration

BenchmarkTools.Trial: 3262 samples with 1 evaluation.
 Range (min  max):  1.384 ms  69.646 ms  ┊ GC (min  max): 0.00%  97.80%
 Time  (median):     1.423 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.531 ms ±  1.607 ms  ┊ GC (mean ± σ):  3.23% ±  6.40%

  █▆▅▄▂                                                      ▁
  ██████▆▆▅▅▅▅▃▁▃▁▁▃▁▁▁▁▁▁▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃ █
  1.38 ms      Histogram: log(frequency) by time     3.84 ms <

 Memory estimate: 662.32 KiB, allocs estimate: 8721.

Only when should_prune_alphas

BenchmarkTools.Trial: 3284 samples with 1 evaluation.
 Range (min  max):  1.420 ms  68.969 ms  ┊ GC (min  max): 0.00%  97.74%
 Time  (median):     1.456 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.521 ms ±  1.236 ms  ┊ GC (mean ± σ):  3.64% ±  6.95%

    ▁▇█▄▂   ▁                                                 
  ▂▅██████▇████▆▆▆▄▄▄▃▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▂▂ ▃
  1.42 ms        Histogram: frequency by time        1.67 ms <

 Memory estimate: 771.65 KiB, allocs estimate: 9099.

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          -28.3501795  -15.6037314  12.7464480985   2          30        
 0.00       10         -16.3057342  -16.2819897  0.0237444537    2          98        
 0.00       20         -16.3054833  -16.3024060  0.0030772894    2          133       
--------------------------------------------------------------------------------------
 0.00       28         -16.3054833  -16.3045949  0.0008883805    2          134       
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          -28.3501795  -15.6037314  12.7464480985   2          30        
 0.00       10         -16.3057342  -16.2819897  0.0237444537    2          98        
 0.00       20         -16.3054833  -16.3024060  0.0030772894    2          133       
--------------------------------------------------------------------------------------
 0.00       28         -16.3054833  -16.3045949  0.0008883805    2          134       
--------------------------------------------------------------------------------------

TigerPOMDP

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 5.0 (for policy run)

Benchmark

At Every Iteration

BenchmarkTools.Trial: 188 samples with 1 evaluation.
 Range (min  max):  25.590 ms  94.617 ms  ┊ GC (min  max): 0.00%  71.50%
 Time  (median):     26.162 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   26.683 ms ±  5.040 ms  ┊ GC (mean ± σ):  1.70% ±  5.60%

   ▁ ▂▅▆██▅▁                                                   
  ▆█▇███████▁▅█▆▅▆▅▁▅▅▁▁▁▅▁▁▁▅▁▁▅▁▁▁▅▅▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▁▁▁▅ ▅
  25.6 ms      Histogram: log(frequency) by time      30.5 ms <

 Memory estimate: 4.02 MiB, allocs estimate: 39898.

Only when should_prune_alphas

BenchmarkTools.Trial: 184 samples with 1 evaluation.
 Range (min  max):  26.127 ms  91.703 ms  ┊ GC (min  max): 0.00%  71.08%
 Time  (median):     26.702 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   27.259 ms ±  4.871 ms  ┊ GC (mean ± σ):  1.89% ±  5.92%

      ▄▇█▆                                                     
  ▆▄▆██████▆▇▇▁▆▄█▁▁▆▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆▁▁▁▁▄▄ ▄
  26.1 ms      Histogram: log(frequency) by time      31.5 ms <

 Memory estimate: 4.81 MiB, allocs estimate: 40544.

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          -10.7779872  87.0980496   97.8760368879   6          44        
 0.01       10         14.2589622   51.5049954   37.2460332328   5          464       
...     
 0.02       40         19.3709835   19.6684906   0.2975071211    5          463       
 0.03       50         19.3713674   19.3833253   0.0119579046    5          212       
--------------------------------------------------------------------------------------
 0.03       57         19.3713684   19.3722266   0.0008581957    5          467       
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          -10.7779872  87.0980496   97.8760368879   6          44        
 0.01       10         14.2589622   51.5049954   37.2460332328   5          464       
...      
 0.02       40         19.3709835   19.6684906   0.2975071211    5          463       
 0.03       50         19.3713674   19.3833253   0.0119579046    5          212       
--------------------------------------------------------------------------------------
 0.03       57         19.3713684   19.3722266   0.0008581957    5          467       
--------------------------------------------------------------------------------------

RockSamplePOMDP(5,5)

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 5.0 (for policy run)

Benchmark

At Every Iteration

BenchmarkTools.Trial: 235 samples with 1 evaluation.
 Range (min  max):  19.918 ms  94.401 ms  ┊ GC (min  max): 0.00%  77.84%
 Time  (median):     20.551 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.292 ms ±  4.979 ms  ┊ GC (mean ± σ):  3.92% ±  7.37%

  ▁▄▅▆▄▄▆█▃▁▁ ▁                                                
  ███████████▅█▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▅▆▅▅▁▆▆▆▅█▁▇▁▅▁▁▆▅▆▅▅▁▇ ▆
  19.9 ms      Histogram: log(frequency) by time      25.3 ms <

 Memory estimate: 11.77 MiB, allocs estimate: 51341.

Only when should_prune_alphas

BenchmarkTools.Trial: 227 samples with 1 evaluation.
 Range (min  max):  21.107 ms  88.421 ms  ┊ GC (min  max): 0.00%  75.12%
 Time  (median):     21.362 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   22.097 ms ±  4.550 ms  ┊ GC (mean ± σ):  2.80% ±  6.42%

  ▃█▇▃ ▂▄▂ ▃▄                                                  
  ████▆███▄███▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▆▁▄▆▆▁▁▁▁▆▄▁▄▁▄▁▁▁▁▁▆▄▄▄ ▆
  21.1 ms      Histogram: log(frequency) by time      26.2 ms <

 Memory estimate: 8.71 MiB, allocs estimate: 41919.

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          15.2423183   18.3402991   3.0979807840    13         14        
 0.01       10         16.9264164   18.1564052   1.2299888405    50         85        
...     
 0.17       270        16.9264164   16.9318942   0.0054778577    140        270       
 0.17       280        16.9264164   16.9291465   0.0027301069    140        236       
--------------------------------------------------------------------------------------
 0.18       290        16.9264164   16.9273809   0.0009645004    140        103       
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.00       0          15.2423183   18.3402991   3.0979807840    13         14        
 0.00       10         16.9264164   18.1564052   1.2299888405    50         85        
... 
 0.17       270        16.9264164   16.9318942   0.0054778577    140        270       
 0.18       280        16.9264164   16.9291465   0.0027301069    146        236       
--------------------------------------------------------------------------------------
 0.18       290        16.9264164   16.9273809   0.0009645004    154        103       
--------------------------------------------------------------------------------------

TagPOMDP

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 5.0 (for policy run)

Benchmark

At Every Iteration

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range (min  max):  901.824 ms  983.309 ms  ┊ GC (min  max): 0.00%  8.59%
 Time  (median):     906.289 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   924.674 ms ±  33.572 ms  ┊ GC (mean ± σ):  2.44% ± 3.75%

  █ ▁ ▁                             ▁                         ▁  
  █▁█▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  902 ms           Histogram: frequency by time          983 ms <

 Memory estimate: 196.05 MiB, allocs estimate: 899624.

Only when should_prune_alphas

BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min  max):  1.086 s    1.235 s  ┊ GC (min  max): 0.00%  8.29%
 Time  (median):     1.093 s              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.120 s ± 64.031 ms  ┊ GC (mean ± σ):  1.83% ± 3.71%

  ███ █                                                   █  
  ███▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.09 s         Histogram: frequency by time        1.23 s <

 Memory estimate: 175.33 MiB, allocs estimate: 879666.

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.01       0          -19.3713764  -4.6944443   14.6769321680   18         48        
 0.09       10         -12.1605923  -4.9155579   7.2450343269    108        345       
...    
 4.14       150        -11.1016246  -5.3819350   5.7196895946    608        2516      
 4.58       160        -11.1016246  -5.3932110   5.7084135463    612        2633      
--------------------------------------------------------------------------------------
 5.02       169        -11.1003199  -5.4068507   5.6934692423    641        2743      
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.01       0          -19.3713764  -4.6944443   14.6769321680   18         48        
 0.10       10         -12.1605923  -4.9155579   7.2450343269    108        345       
...    
 4.31       140        -11.1202809  -5.3649098   5.7553711760    575        2380      
 4.78       150        -11.1016246  -5.3819350   5.7196895946    629        2516      
--------------------------------------------------------------------------------------
 5.02       156        -11.1016246  -5.3906492   5.7109754343    652        2579      
--------------------------------------------------------------------------------------

TagPOMDP

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 60.0 (for policy run)

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.01       0          -19.3713764  -4.6944443   14.6769321680   18         48        
 0.09       10         -12.1605923  -4.9155579   7.2450343269    108        345       
 ...  
 57.21      550        -10.9227351  -5.7750679   5.1476671438    1100       7249      
 59.61      560        -10.9227351  -5.7818944   5.1408406919    1113       7351      
--------------------------------------------------------------------------------------
 60.08      563        -10.9227351  -5.7841025   5.1386325673    1114       7373      
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 0.01       0          -19.3713764  -4.6944443   14.6769321680   18         48        
 0.11       10         -12.1605923  -4.9155579   7.2450343269    108        345       
 ...
 56.39      530        -10.9280846  -5.7611180   5.1669666193    1082       6906      
 59.07      540        -10.9280846  -5.7686969   5.1593877438    1152       7060      
--------------------------------------------------------------------------------------
 60.22      544        -10.9280846  -5.7698137   5.1582709201    1110       7119      
--------------------------------------------------------------------------------------

RockSamplePOMDP(15,10)

Settings:

  • epsilon: 0.1
  • precision: 0.001
  • delta: 0.0001
  • max_steps: 50 (for benchmarking)
  • max_time: 120.0 (for policy run)
  • init_lower: BlindLowerBound(9223372036854775807, 60.0, 0.001, Float64[], Float64[])
  • init_upper: FastInformedBound(9223372036854775807, 60.0, 0.001, 0.0, Float64[], Float64[])

Policy Run

At Every Iteration

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 1.16       0          14.9526422   18.9521036   3.9994614354    31         36        
 23.02      10         15.5252008   18.4262322   2.9010313613    299        472       
 56.36      20         15.6821109   18.2229256   2.5408147418    526        819       
 106.26     30         15.7936995   18.0692991   2.2755996353    766        1124      
--------------------------------------------------------------------------------------
 123.90     34         15.8091284   18.0564031   2.2472746795    837        1236      
--------------------------------------------------------------------------------------

Only when should_prune_alphas

--------------------------------------------------------------------------------------
 Time       Iter       LB           UB           Precision       # Alphas   # Beliefs 
--------------------------------------------------------------------------------------
 1.16       0          14.9526422   18.9521036   3.9994614354    31         36        
 24.87      10         15.5252008   18.4262322   2.9010313613    299        472       
 66.25      20         15.6821109   18.2229256   2.5408147418    526        819       
 127.53     30         15.7936995   18.0692991   2.2755996353    766        1124      
--------------------------------------------------------------------------------------
 127.53     31         15.7936995   18.0692991   2.2755996353    766        1124      
--------------------------------------------------------------------------------------

if should_prune_alphas(tree)
prune_alpha!(tree, solver.delta)
end
Expand Down Expand Up @@ -48,60 +49,120 @@ function prune!(tree::SARSOPTree)
end
end

function belief_space_domination(α1, α2, B, δ)
a1_dominant = true
a2_dominant = true
for b ∈ B
!a1_dominant && !a2_dominant && return (false, false)
δV = intersection_distance(α1, α2, b)
δV ≤ δ && (a1_dominant = false)
δV ≥ -δ && (a2_dominant = false)
end
return a1_dominant, a2_dominant
end

@inline function intersection_distance(α1, α2, b)
s = 0.0
function intersection_distance(α1, α2, b)
dot_sum = 0.0
I,B = b.nzind, b.nzval
@inbounds for _i ∈ eachindex(I)
I, B = b.nzind, b.nzval
for _i ∈ eachindex(I)
i = I[_i]
diff = α1[i] - α2[i]
s += abs2(diff)
dot_sum += diff*B[_i]
dot_sum += (α1[i] - α2[i]) * B[_i]
end
s = 0.0
for i ∈ eachindex(α1, α2)
s += (α1[i] - α2[i])^2
end
return dot_sum / sqrt(s)
end

function prune_alpha!(tree::SARSOPTree, δ)
function prune_alpha!(tree::SARSOPTree, δ, eps=0.0)
Γ = tree.Γ
B_valid = tree.b[map(!,tree.b_pruned)]
pruned = falses(length(Γ))

# checking if α_i dominates α_j
for (i,α_i) ∈ enumerate(Γ)
pruned[i] && continue
for (j,α_j) ∈ enumerate(Γ)
(j ≤ i || pruned[j]) && continue
a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ)
#=
NOTE: α1 and α2 shouldn't technically be able to mutually dominate
i.e. a1_dominant and a2_dominant should never both be true.
But this does happen when α1 == α2 because intersection_distance returns NaN.
Current impl prunes α2 without doing an equality check, removing
the duplicate α. Could do equality check to short-circuit
belief_space_domination which would speed things up if we have
a lot of duplicates, but the equality check can slow things down
if α's are sufficiently diverse.
=#
if a1_dominant
pruned[j] = true
elseif a2_dominant
pruned[i] = true
break
B_valid = tree.b[map(!, tree.b_pruned)]

n_Γ = length(Γ)
n_B = length(B_valid)

dominant_indices_bools = falses(n_Γ)
dominant_vector_indices = Vector{Int}(undef, n_B)

# First, identify dominant alpha vectors
for b_idx in 1:n_B
max_value = -Inf
max_index = -1
for i in 1:n_Γ
value = dot(Γ[i], B_valid[b_idx])
if value > max_value
max_value = value
max_index = i
end
end
dominant_indices_bools[max_index] = true
dominant_vector_indices[b_idx] = max_index
end
deleteat!(Γ, pruned)

non_dominant_indices = findall(!, dominant_indices_bools)
n_non_dom = length(non_dominant_indices)
keep_non_dom = falses(n_non_dom)

for b_idx in 1:n_B
dom_vec_idx = dominant_vector_indices[b_idx]
for j in 1:n_non_dom
non_dom_idx = non_dominant_indices[j]
if keep_non_dom[j]
continue
end
intx_dist = intersection_distance(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx])
if !isnan(intx_dist) && (intx_dist + eps ≤ δ)
keep_non_dom[j] = true
end
end
end

non_dominant_indices = non_dominant_indices[.!keep_non_dom]
deleteat!(Γ, non_dominant_indices)
tree.prune_data.last_Γ_size = length(Γ)
end

function strictly_dominates(α1, α2, eps)
for ii in 1:length(α1)
if α1[ii] < α2[ii] - eps
return false
end
end
return true
end

function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10)
Γ = tree.Γ
Γ_new_idxs = Vector{Int}(undef, length(Γ))
keep = trues(length(Γ))

idx_count = 0
for (α_try_idx, α_try) in enumerate(Γ)
dominated = false
for jj in 1:idx_count
α_in_idx = Γ_new_idxs[jj]
α_in = Γ[α_in_idx]
if strictly_dominates(α_try, α_in, eps)
keep[jj] = false
elseif strictly_dominates(α_in, α_try, eps)
dominated = true
break
end
end
if !dominated
new_idx_count = 0
for jj in 1:idx_count
if keep[jj]
new_idx_count += 1
Γ_new_idxs[new_idx_count] = Γ_new_idxs[jj]
end
end
new_idx_count += 1
Γ_new_idxs[new_idx_count] = α_try_idx
idx_count = new_idx_count
fill!(keep, true)
end
end

resize!(Γ_new_idxs, idx_count)

to_delete = trues(length(Γ))
for idx in Γ_new_idxs
to_delete[idx] = false
end

for ii in length(Γ):-1:1
if to_delete[ii]
deleteat!(Γ, ii)
end
end
end
19 changes: 19 additions & 0 deletions test/prune.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testset "prune" begin
# NativeSARSOP.strictly_dominates
a1 = [1.0, 2.0, 3.0]
a2 = [1.0, 2.1, 2.9]
a3 = [0.9, 1.9, 2.9]
@test !NativeSARSOP.strictly_dominates(a1, a2, 1e-10)
@test NativeSARSOP.strictly_dominates(a1, a1, 1e-10)
@test NativeSARSOP.strictly_dominates(a1, a3, 1e-10)

# NativeSARSOP.intersection_distance
b = SparseVector([1.0, 0.0])
a1 = [1.0, 0.0]
a2 = [0.0, 1.0]
@test isapprox(NativeSARSOP.intersection_distance(a1, a2, b),
sqrt(0.5^2 + 0.5^2), atol=1e-10)

b = SparseVector([0.5, 0.5])
@test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), 0.0, atol=1e-10)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ include("sample.jl")

include("updater.jl")

include("prune.jl")

include("tree.jl")

@testset "Tiger POMDP" begin
Expand Down
Loading