-
-
Notifications
You must be signed in to change notification settings - Fork 53
[optgrowth_fast] Migrate to JAX and check against style guide #655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (ff7e42b) 📚 Changed Lecture Pages: optgrowth_fast |
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (a4f7341) 📚 Changed Lecture Pages: optgrowth_fast |
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (4b966c1) 📚 Changed Lecture Pages: optgrowth_fast |
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (2b606ca) 📚 Changed Lecture Pages: optgrowth_fast |
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (9be4a24) 📚 Changed Lecture Pages: optgrowth_fast |
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (9f1c5e2) 📚 Changed Lecture Pages: optgrowth_fast |
…/lecture-python.myst into opt-growth-fast-review
|
📖 Netlify Preview Ready! Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (2cff4d8) 📚 Changed Lecture Pages: optgrowth_fast |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR migrates the "Optimal Growth II: Accelerating the Code" lecture from Numba-based JIT compilation to JAX-based implementation. The changes modernize the lecture to use JAX's functional programming paradigm and automatic differentiation capabilities.
Key changes:
- Replaces Numba's
@jitclasswith JAX'sNamedTuplefor model storage - Implements a custom golden section search optimizer compatible with JAX's control flow
- Updates all numerical operations from NumPy to JAX's
jax.numpy(jnp)
| args=(y, v, model)) | ||
| return c_star, v_max | ||
| v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid) |
Copilot
AI
Oct 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The jax.vmap call unpacks two return values, but maximize_at_state returns a tuple. JAX's vmap will vectorize this tuple, resulting in v_greedy and v_new each being a single array. However, the function is structured to return them separately. This works because vmap treats tuple outputs correctly, but it would be clearer to explicitly unpack: results = jax.vmap(maximize_at_state)(y_grid) followed by v_greedy, v_new = results[0], results[1] or use jax.vmap with out_axes=(0, 0) to make the intent explicit.
| v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid) | |
| results = jax.vmap(maximize_at_state)(y_grid) | |
| v_greedy, v_new = results[0], results[1] |
| ξ = jr.normal(key, (ts_length - 1,)) | ||
| y = np.empty(ts_length) |
Copilot
AI
Oct 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function mixes JAX random number generation (jr.normal) with NumPy arrays (np.empty, np.exp). For consistency and to maintain JAX's functional purity, consider using jnp.empty and converting to NumPy only if needed for output. Alternatively, document why NumPy is used here (e.g., for compatibility with matplotlib or performance reasons in the simulation loop).
This PR migrates
optgrowth_fastto JAX. The result runs 4 times faster thannumbaon CPU.