Skip to content

[numpy_vs_numba_vs_jax] Fix deprecated device= argument on jax.jit#563

Merged
mmcky merged 1 commit into
mainfrom
fix-jax-jit-device-deprecation
Jun 19, 2026
Merged

[numpy_vs_numba_vs_jax] Fix deprecated device= argument on jax.jit#563
mmcky merged 1 commit into
mainfrom
fix-jax-jit-device-deprecation

Conversation

@mmcky

@mmcky mmcky commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

Summary

The numpy_vs_numba_vs_jax lecture pins two jax.jit-compiled functions to the CPU using the device= argument of jax.jit. Recent JAX versions have deprecated the device/backend arguments to jit, so the lecture emits two DeprecationWarnings that are captured into the rendered HTML output:

DeprecationWarning: backend and device argument on jit is deprecated. You can use jax.device_put(..., jax.local_devices(backend="cpu")[0]) on the inputs to the jitted function to get the same behavior.

Fix

Pin the computation to the CPU the recommended way — commit the input to the CPU with jax.device_put — instead of the deprecated decorator argument:

  • Drop device=cpu from both @partial(jax.jit, ...) decorators (qm_jax_fori, qm_jax_scan).
  • Add x0_cpu = jax.device_put(0.1, cpu) and pass it at the call sites; committing the input to the CPU keeps the whole computation on the CPU.
  • Update the explanatory bullet to describe the new approach.

CPU execution (important for the timing narrative of this sequential workload) is unchanged; this only removes the deprecation warning from the HTML.

Context

Found while validating the anaconda=2026.06 bump (#562). This warning is independent of that bump — JAX is pip-installed separately — but it was the only remaining warning in the built HTML, so this cleans it up.

JAX has deprecated the device/backend arguments to jax.jit, which emitted
two DeprecationWarnings into the lecture's rendered HTML output:

    DeprecationWarning: backend and device argument on jit is deprecated.
    You can use jax.device_put(..., jax.local_devices(backend="cpu")[0])
    on the inputs to the jitted function to get the same behavior.

Pin the computation to the CPU the recommended way -- commit the input to
the CPU with jax.device_put -- instead of the deprecated decorator
argument. The behaviour (CPU execution for this sequential workload) is
unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 19, 2026 02:17

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Removes JAX DeprecationWarnings from the numpy_vs_numba_vs_jax lecture by updating how CPU pinning is done for jax.jit-compiled functions, aligning the lecture with current JAX guidance.

Changes:

  • Removed deprecated device=cpu usage from @partial(jax.jit, ...) decorators.
  • Added a CPU-placed input (x0_cpu = jax.device_put(0.1, cpu)) and updated call sites to pass it in, keeping the computation on CPU.
  • Updated the explanatory bullet to reflect the new “pin via input placement” approach.

@github-actions

Copy link
Copy Markdown

@github-actions github-actions Bot temporarily deployed to pull request June 19, 2026 02:24 Inactive
@mmcky

mmcky commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

@HumphreyYang @jstac -- just FYI re: deprecations in JAX.

@jstac

jstac commented Jun 19, 2026

Copy link
Copy Markdown
Contributor

Thanks @mmcky , appreciated. pls go ahead.

@mmcky

mmcky commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

Thanks @jstac. Merging now.

The only other alternative is to wrap functions using a context manager.

with jax.default_device(cpu):
   {code}

but I think this is cleaner.

@mmcky mmcky merged commit d37b1d8 into main Jun 19, 2026
6 checks passed
@mmcky mmcky deleted the fix-jax-jit-device-deprecation branch June 19, 2026 03:18
@mmcky

mmcky commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

✅ Translation sync completed (zh-cn)

Target repo: QuantEcon/lecture-python-programming.zh-cn
Translation PR: QuantEcon/lecture-python-programming.zh-cn#68
Files synced (1):

  • lectures/numpy_vs_numba_vs_jax.md

@mmcky

mmcky commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

✅ Translation sync completed (fa)

Target repo: QuantEcon/lecture-python-programming.fa
Translation PR: QuantEcon/lecture-python-programming.fa#128
Files synced (1):

  • lectures/numpy_vs_numba_vs_jax.md

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants