
Conversation

@HumphreyYang
Member

This PR migrates optgrowth_fast to JAX. The result runs 4 times faster than the Numba version on CPU.

@github-actions

📖 Netlify Preview Ready!

Preview URL: https://pr-655--sunny-cactus-210e3e.netlify.app (ff7e42b)

📚 Changed Lecture Pages: optgrowth_fast

@HumphreyYang HumphreyYang marked this pull request as ready for review October 29, 2025 02:35

Contributor

Copilot AI left a comment


Pull Request Overview

This PR migrates the "Optimal Growth II: Accelerating the Code" lecture from Numba-based JIT compilation to a JAX-based implementation. The changes modernize the lecture to use JAX's functional programming paradigm and its composable transformations such as jit and vmap.

Key changes:

  • Replaces Numba's @jitclass with JAX's NamedTuple for model storage
  • Implements a custom golden section search optimizer compatible with JAX's control flow
  • Updates all numerical operations from NumPy to JAX's jax.numpy (jnp)
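The first two points above can be sketched as follows. This is a minimal illustration, not the lecture's actual code: the Model fields and the toy objective are placeholders, and the fixed iteration count is the detail that makes the search compatible with jax.jit's control-flow restrictions.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

# Hypothetical model container: field names are placeholders, not the
# lecture's actual parameters.  A NamedTuple is a JAX pytree, so it
# passes through jit/vmap without the special handling @jitclass needed.
class Model(NamedTuple):
    α: float = 0.4   # production parameter (placeholder)
    β: float = 0.96  # discount factor (placeholder)

def golden_section_search(f, a, b, n_iters=50):
    """Fixed-iteration golden section maximizer.

    Running a fixed number of iterations via lax.fori_loop avoids
    data-dependent Python control flow, which is what makes this
    usable inside jit-compiled code.
    """
    phi = (jnp.sqrt(5.0) - 1) / 2  # inverse golden ratio, about 0.618

    def body(_, bracket):
        a, b = bracket
        c = b - phi * (b - a)
        d = a + phi * (b - a)
        # Keep the sub-interval that contains the larger function value
        return jax.lax.cond(f(c) < f(d),
                            lambda: (c, b),
                            lambda: (a, d))

    a, b = jax.lax.fori_loop(0, n_iters, body, (a, b))
    x_star = (a + b) / 2
    return x_star, f(x_star)

# Maximize a concave toy objective over [0, 5]; the peak is at x = 2
x, fx = golden_section_search(lambda x: -(x - 2.0) ** 2, 0.0, 5.0)
```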

                args=(y, v, model))
        return c_star, v_max
    v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid)

Copilot AI Oct 29, 2025


The jax.vmap call unpacks two return values, and maximize_at_state returns a tuple. vmap preserves the tuple structure, mapping each element over the batch axis, so v_greedy and v_new each come back as a single array. The code is therefore correct, but the intent would be clearer with an explicit unpack: results = jax.vmap(maximize_at_state)(y_grid) followed by v_greedy, v_new = results[0], results[1], or by passing out_axes=(0, 0) to jax.vmap.

Suggested change

    - v_greedy, v_new = jax.vmap(maximize_at_state)(y_grid)
    + results = jax.vmap(maximize_at_state)(y_grid)
    + v_greedy, v_new = results[0], results[1]
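The behavior described in the comment can be checked in isolation. Below, a toy function stands in for maximize_at_state; it confirms that vmap maps a tuple-returning function to a tuple of arrays, so direct unpacking works either way:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for maximize_at_state: returns a (scalar, scalar) tuple
def f(y):
    return y * 2.0, y ** 2

y_grid = jnp.array([1.0, 2.0, 3.0])

# vmap maps over the leading axis and preserves the tuple structure,
# so the result is a tuple of two arrays and unpacks directly
a, b = jax.vmap(f)(y_grid)
# a == [2., 4., 6.],  b == [1., 4., 9.]
```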

Comment on lines +410 to 411
    ξ = jr.normal(key, (ts_length - 1,))
    y = np.empty(ts_length)

Copilot AI Oct 29, 2025


The function mixes JAX random number generation (jr.normal) with NumPy arrays (np.empty, np.exp). For consistency and to maintain JAX's functional purity, consider using jnp.empty and converting to NumPy only if needed for output. Alternatively, document why NumPy is used here (e.g., for compatibility with matplotlib or performance reasons in the simulation loop).
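One possible pure-JAX version is sketched below, using lax.scan to roll the state forward instead of filling a NumPy buffer in a Python loop. The function name, the parameters, and the law of motion (y ** 0.9 * exp(s * shock)) are illustrative placeholders, not the lecture's actual update:

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def simulate_income(key, y0=1.0, ts_length=100, s=0.1):
    # Generate all shocks up front, as in the original snippet
    ξ = jr.normal(key, (ts_length - 1,))

    def step(y, shock):
        # Placeholder law of motion; the lecture's update may differ
        y_next = y ** 0.9 * jnp.exp(s * shock)
        return y_next, y_next  # (new carry, output for this step)

    # lax.scan threads the state through the shock sequence without
    # mutating a pre-allocated array, keeping the function pure
    _, path = jax.lax.scan(step, y0, ξ)
    return jnp.concatenate([jnp.array([y0]), path])

y = simulate_income(jr.PRNGKey(0))
```

Whether this is preferable depends on the lecture's goals; calling np.asarray on the result for plotting is also a reasonable, documented compromise.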
