Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lectures/numpy_vs_numba_vs_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,10 @@ We'll apply a `lax.fori_loop`, which is a version of a for loop that can be comp
```{code-cell} ipython3
cpu = jax.devices("cpu")[0]

@partial(jax.jit, static_argnames=("n",), device=cpu)
# Pin the input to the CPU, which keeps the whole computation there
x0_cpu = jax.device_put(0.1, cpu)

@partial(jax.jit, static_argnames=("n",))
def qm_jax_fori(x0, n, α=4.0):

x = jnp.empty(n + 1).at[0].set(x0)
Expand All @@ -485,7 +488,7 @@ def qm_jax_fori(x0, n, α=4.0):
```

* We hold `n` static because it affects array size and hence JAX wants to specialize on its value in the compiled code.
* We pin to the CPU via `device=cpu` because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism.
* We pin the input to the CPU with `jax.device_put` (which keeps the whole computation on the CPU) because this sequential workload consists of many small operations, leaving little opportunity for GPU parallelism.

Important: Although `at[t].set` appears to create a new array at each step, inside a JIT-compiled function the compiler detects that the old array is no longer needed and performs the update in place!

Expand All @@ -494,7 +497,7 @@ Let's time it with the same parameters:
```{code-cell} ipython3
with qe.Timer():
# First run
x_jax = qm_jax_fori(0.1, n)
x_jax = qm_jax_fori(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -504,7 +507,7 @@ Let's run it again to eliminate compilation overhead:
```{code-cell} ipython3
with qe.Timer():
# Second run
x_jax = qm_jax_fori(0.1, n)
x_jax = qm_jax_fori(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -521,7 +524,7 @@ although the syntax is difficult to remember.


```{code-cell} ipython3
@partial(jax.jit, static_argnames=("n",), device=cpu)
@partial(jax.jit, static_argnames=("n",))
def qm_jax_scan(x0, n, α=4.0):
def update(x, t):
x_new = α * x * (1 - x)
Expand All @@ -538,7 +541,7 @@ Let's time it with the same parameters:
```{code-cell} ipython3
with qe.Timer():
# First run
x_jax = qm_jax_scan(0.1, n)
x_jax = qm_jax_scan(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand All @@ -548,7 +551,7 @@ Let's run it again to eliminate compilation overhead:
```{code-cell} ipython3
with qe.Timer():
# Second run
x_jax = qm_jax_scan(0.1, n)
x_jax = qm_jax_scan(x0_cpu, n)
# Hold interpreter
x_jax.block_until_ready()
```
Expand Down
Loading