From 5581a30c05258251fc23d6474dfce7bb6ed6e3c6 Mon Sep 17 00:00:00 2001
From: kazewong
Date: Wed, 15 Nov 2023 16:13:50 -0500
Subject: [PATCH] update intro

---
 reveal/SciwareJax2023/intro.html | 147 ++++++++++++++++++++-----------
 1 file changed, 95 insertions(+), 52 deletions(-)

diff --git a/reveal/SciwareJax2023/intro.html b/reveal/SciwareJax2023/intro.html
index f4a7fa50e97..d78fbc236c2 100644
--- a/reveal/SciwareJax2023/intro.html
+++ b/reveal/SciwareJax2023/intro.html
@@ -1,72 +1,115 @@
-Jax
+Jax - Good, Better, Atrocious
+Kaze Wong
-O4 started a couple of months ago
+Infrastructure around AI is just as cool
+Jax
-Events are coming in hot!
+1. Autodiff
+2. JIT compilation
+3. Simple vectorization
+4. GPU with XLA
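A minimal sketch of how these four pieces compose in practice (my illustration, not part of the patch; the GPU step assumes a CUDA-enabled jaxlib, and everything still runs on CPU without one):

        import jax
        import jax.numpy as jnp

        def loss(w):
            return jnp.sum(w ** 2)   # scalar output, as jax.grad requires

        # autodiff -> vectorization -> JIT compilation, composed freely;
        # XLA places the compiled result on a GPU automatically when present
        batched_grad = jax.jit(jax.vmap(jax.grad(loss)))
        print(batched_grad(jnp.ones((4, 3))))   # gradient of sum(w**2) is 2 * w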
-The future is on the horizon
+Jax basic - your normal Python
+
+        import jax.numpy as jnp
+
+        def f(x):
+            return x ** x
+
+        x = jnp.arange(1, 10)
+        f(x)
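One caveat worth a note beside this slide (my addition, going slightly beyond what the slide shows): jax.numpy mirrors the NumPy API, but JAX arrays are immutable, so in-place assignment is replaced by the functional .at syntax:

        import jax.numpy as jnp

        x = jnp.arange(1, 10)
        # x[0] = 99 would raise a TypeError: JAX arrays are immutable
        y = x.at[0].set(99)   # returns a new array with the update applied
        print(y)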
-Machine learning in GW
-Search
-Simulation
-Inference
+Jax basic - grad
+
+        import jax.numpy as jnp
+        import jax
+
+        def f(x):
+            return x ** x
+
+        x = jnp.arange(1, 10.)
+        df = jax.grad(f)
+        # check against the analytic derivative: d/dx x**x = (1 + log x) * x**x
+        print("Check grad(f): ", df(3.) == (1 + jnp.log(3.)) * f(3.))
+        print("Try grad of f on array: ", df(x))
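The last line above is presumably the slide's deliberate cliffhanger: jax.grad is only defined for scalar-output functions, so df(x) on an array raises a TypeError, which the vmap slide below resolves. For completeness, a small sketch (my addition) of jax.value_and_grad, which returns the value and the gradient in one call:

        import jax
        import jax.numpy as jnp

        def f(x):
            return x ** x

        vg = jax.value_and_grad(f)   # still restricted to scalar outputs
        val, grad = vg(3.)
        print(val, grad)             # f(3) = 27, f'(3) = (1 + log 3) * 27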
+Jax basic - vmap
+
+        import jax.numpy as jnp
+        import jax
+
+        def f(x):
+            return x ** x
+
+        x = jnp.arange(1, 10.)
+        # vmap maps the scalar-only gradient over the array elementwise
+        df = jax.vmap(jax.grad(f))
+        print("Try grad of f on array: ", df(x))
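An aside on vmap's main knob (my sketch, not from the slides): in_axes chooses which argument axes are mapped, so fixed parameters can be broadcast to every call:

        import jax
        import jax.numpy as jnp

        def scale(x, a):
            return a * x ** 2

        xs = jnp.arange(1., 4.)
        # map over xs along axis 0, hold the scalar a fixed for every call
        batched_scale = jax.vmap(scale, in_axes=(0, None))
        print(batched_scale(xs, 10.))   # -> [10., 40., 90.]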
+Jax basic - jit
+
+        import jax.numpy as jnp
+        import jax
+
+        def f(x):
+            return x * x + 2 * x
+
+        x = jnp.ones((5000, 5000))
+        fast_f = jax.jit(f)
+        print("Benchmarking f(x)...")
+        %timeit f(x)
+        print("Benchmarking fast_f(x)...")
+        %timeit fast_f(x)
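A benchmarking caveat worth attaching here (my note): JAX dispatches work asynchronously, so timing the bare call can measure dispatch rather than compute. Blocking on the result gives honest numbers, e.g. in the same IPython session:

        %timeit f(x).block_until_ready()
        %timeit fast_f(x).block_until_ready()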
+Jax basic - EZ GPU
+
+        import jax.numpy as jnp
+        import jax
+
+        def f(x):
+            return x * x + 2 * x
+
+        x = jnp.ones((5000, 5000))
+        cpu_f = jax.jit(f, backend="cpu")
+        gpu_f = jax.jit(f, backend="gpu")
+
+        print("Benchmarking cpu_f(x)...")
+        %timeit cpu_f(x)
+
+        print("Benchmarking gpu_f(x)...")
+        %timeit gpu_f(x)
\ No newline at end of file
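A small companion sketch (mine; it assumes jaxlib was installed with GPU support, otherwise the backend="gpu" call above errors out) for checking which backends are actually available first:

        import jax

        print(jax.devices())           # devices on the default backend
        print(jax.default_backend())   # "cpu", "gpu", or "tpu"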