Skip to content

Commit

Permalink
Deploying to gh-pages from @ 3dc3809 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Mar 25, 2024
1 parent 34e8473 commit c822362
Show file tree
Hide file tree
Showing 53 changed files with 188 additions and 96 deletions.
Binary file modified .doctrees/autoapi/blackjax/smc/base/index.doctree
Binary file not shown.
Binary file modified .doctrees/autoapi/blackjax/smc/index.doctree
Binary file not shown.
Binary file modified .doctrees/autoapi/blackjax/smc/inner_kernel_tuning/index.doctree
Binary file not shown.
Binary file modified .doctrees/environment.pickle
Binary file not shown.
Binary file modified .doctrees/examples/howto_custom_gradients.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_metropolis_within_gibbs.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_sample_multiple_chains.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_aesara.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_numpyro.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_oryx.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_pymc.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_tfp.doctree
Binary file not shown.
Binary file modified .doctrees/examples/quickstart.doctree
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions _modules/blackjax/_version.html
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ <h1>Source code for blackjax._version</h1><div class="highlight"><pre>
<span class="n">version_tuple</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span></div>


<span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;0.1.dev1+g2ccdfb0&#39;</span>
<span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;dev1&#39;</span><span class="p">,</span> <span class="s1">&#39;g2ccdfb0&#39;</span><span class="p">)</span>
<span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;0.1.dev1+g3dc3809&#39;</span>
<span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;dev1&#39;</span><span class="p">,</span> <span class="s1">&#39;g3dc3809&#39;</span><span class="p">)</span>
</pre></div>

</article>
Expand Down
26 changes: 22 additions & 4 deletions _modules/blackjax/smc/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ <h1>Source code for blackjax.smc.base</h1><div class="highlight"><pre>
<div class="viewcode-block" id="SMCState.weights">
<a class="viewcode-back" href="../../../autoapi/blackjax/smc/base/index.html#blackjax.smc.base.SMCState.weights">[docs]</a>
<span class="n">weights</span><span class="p">:</span> <span class="n">Array</span></div>

<div class="viewcode-block" id="SMCState.update_parameters">
<a class="viewcode-back" href="../../../autoapi/blackjax/smc/base/index.html#blackjax.smc.base.SMCState.update_parameters">[docs]</a>
<span class="n">update_parameters</span><span class="p">:</span> <span class="n">ArrayTree</span></div>
</div>


Expand Down Expand Up @@ -450,12 +454,12 @@ <h1>Source code for blackjax.smc.base</h1><div class="highlight"><pre>

<div class="viewcode-block" id="init">
<a class="viewcode-back" href="../../../autoapi/blackjax/smc/base/index.html#blackjax.smc.base.init">[docs]</a>
<span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="n">particles</span><span class="p">:</span> <span class="n">ArrayLikeTree</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="n">particles</span><span class="p">:</span> <span class="n">ArrayLikeTree</span><span class="p">,</span> <span class="n">init_update_params</span><span class="p">):</span>
<span class="c1"># Infer the number of particles from the size of the leading dimension of</span>
<span class="c1"># the first leaf of the inputted PyTree.</span>
<span class="n">num_particles</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_flatten</span><span class="p">(</span><span class="n">particles</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_particles</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_particles</span>
<span class="k">return</span> <span class="n">SMCState</span><span class="p">(</span><span class="n">particles</span><span class="p">,</span> <span class="n">weights</span><span class="p">)</span></div>
<span class="k">return</span> <span class="n">SMCState</span><span class="p">(</span><span class="n">particles</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">init_update_params</span><span class="p">)</span></div>



Expand Down Expand Up @@ -531,17 +535,31 @@ <h1>Source code for blackjax.smc.base</h1><div class="highlight"><pre>
<span class="n">particles</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="n">resampling_idx</span><span class="p">],</span> <span class="n">state</span><span class="o">.</span><span class="n">particles</span><span class="p">)</span>

<span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">updating_key</span><span class="p">,</span> <span class="n">num_resampled</span><span class="p">)</span>
<span class="n">particles</span><span class="p">,</span> <span class="n">update_info</span> <span class="o">=</span> <span class="n">update_fn</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">particles</span><span class="p">)</span>
<span class="n">particles</span><span class="p">,</span> <span class="n">update_info</span> <span class="o">=</span> <span class="n">update_fn</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">particles</span><span class="p">,</span> <span class="n">state</span><span class="o">.</span><span class="n">update_parameters</span><span class="p">)</span>

<span class="n">log_weights</span> <span class="o">=</span> <span class="n">weight_fn</span><span class="p">(</span><span class="n">particles</span><span class="p">)</span>
<span class="n">logsum_weights</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">scipy</span><span class="o">.</span><span class="n">special</span><span class="o">.</span><span class="n">logsumexp</span><span class="p">(</span><span class="n">log_weights</span><span class="p">)</span>
<span class="n">normalizing_constant</span> <span class="o">=</span> <span class="n">logsum_weights</span> <span class="o">-</span> <span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">num_particles</span><span class="p">)</span>
<span class="n">weights</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_weights</span> <span class="o">-</span> <span class="n">logsum_weights</span><span class="p">)</span>

<span class="k">return</span> <span class="n">SMCState</span><span class="p">(</span><span class="n">particles</span><span class="p">,</span> <span class="n">weights</span><span class="p">),</span> <span class="n">SMCInfo</span><span class="p">(</span>
<span class="k">return</span> <span class="n">SMCState</span><span class="p">(</span><span class="n">particles</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">state</span><span class="o">.</span><span class="n">update_parameters</span><span class="p">),</span> <span class="n">SMCInfo</span><span class="p">(</span>
<span class="n">resampling_idx</span><span class="p">,</span> <span class="n">normalizing_constant</span><span class="p">,</span> <span class="n">update_info</span>
<span class="p">)</span></div>



<div class="viewcode-block" id="extend_params">
<a class="viewcode-back" href="../../../autoapi/blackjax/smc/base/index.html#blackjax.smc.base.extend_params">[docs]</a>
<span class="k">def</span> <span class="nf">extend_params</span><span class="p">(</span><span class="n">n_particles</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Given a dictionary of params, repeats them for every single particle. The expected</span>
<span class="sd"> usage is in cases where the aim is to repeat the same parameters for all chains within SMC.</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="k">def</span> <span class="nf">extend</span><span class="p">(</span><span class="n">param</span><span class="p">):</span>
<span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">param</span><span class="p">)[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">n_particles</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

<span class="k">return</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">extend</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span></div>

</pre></div>

</article>
Expand Down
Loading

0 comments on commit c822362

Please sign in to comment.