You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Logistic regression classifier using JAX to support GPU acceleration.
3
+
4
+
This class is an update of a logistic regression class used in my intro to machine learning course. The major difference is the handling of the gradient descent operations,
5
+
which were rewritten using jax's grad, jit, and vmap functions. The goal with this project is speed - I've found that using JaxReg with GPU acceleration gives a ~15x speed
0 commit comments