[numpy_vs_numba_vs_jax] Fix deprecated device= argument on jax.jit#563
Merged
Conversation
JAX has deprecated the device/backend arguments to jax.jit, which emitted
two DeprecationWarnings into the lecture's rendered HTML output:
DeprecationWarning: backend and device argument on jit is deprecated.
You can use jax.device_put(..., jax.local_devices(backend="cpu")[0])
on the inputs to the jitted function to get the same behavior.
Pin the computation to the CPU the recommended way -- commit the input to
the CPU with jax.device_put -- instead of the deprecated decorator
argument. The behaviour (CPU execution for this sequential workload) is
unchanged.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Contributor
There was a problem hiding this comment.
Pull request overview
Removes JAX DeprecationWarnings from the numpy_vs_numba_vs_jax lecture by updating how CPU pinning is done for jax.jit-compiled functions, aligning the lecture with current JAX guidance.
Changes:
- Removed deprecated
device=cpuusage from@partial(jax.jit, ...)decorators. - Added a CPU-placed input (
x0_cpu = jax.device_put(0.1, cpu)) and updated call sites to pass it in, keeping the computation on CPU. - Updated the explanatory bullet to reflect the new “pin via input placement” approach.
Contributor
Author
|
@HumphreyYang @jstac -- just FYI re: deprecations in JAX. |
Contributor
|
Thanks @mmcky , appreciated. pls go ahead. |
Contributor
Author
|
Thanks @jstac. Merging now. The only other alternative is to wrap functions using a context manager. but I think this is cleaner. |
Contributor
Author
✅ Translation sync completed (zh-cn)Target repo: QuantEcon/lecture-python-programming.zh-cn
|
Contributor
Author
✅ Translation sync completed (fa)Target repo: QuantEcon/lecture-python-programming.fa
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The
numpy_vs_numba_vs_jaxlecture pins twojax.jit-compiled functions to the CPU using thedevice=argument ofjax.jit. Recent JAX versions have deprecated thedevice/backendarguments tojit, so the lecture emits twoDeprecationWarnings that are captured into the rendered HTML output:Fix
Pin the computation to the CPU the recommended way — commit the input to the CPU with
jax.device_put— instead of the deprecated decorator argument:device=cpufrom both@partial(jax.jit, ...)decorators (qm_jax_fori,qm_jax_scan).x0_cpu = jax.device_put(0.1, cpu)and pass it at the call sites; committing the input to the CPU keeps the whole computation on the CPU.CPU execution (important for the timing narrative of this sequential workload) is unchanged; this only removes the deprecation warning from the HTML.
Context
Found while validating the
anaconda=2026.06bump (#562). This warning is independent of that bump — JAX ispip-installed separately — but it was the only remaining warning in the built HTML, so this cleans it up.