from jax import jit, grad, vmap, device_put, random
import jax.numpy as jnp
from functools import partial
import time


class JaxReg:
8+ """
9+ Logistic regression classifier with GPU acceleration support through Google's JAX. The point of this class is fitting speed: I want this
10+ to fit a model for very large datasets (k49 in particular) as quickly as possible!
11+
12+ - jit compilation utilized in sigma and loss methods (strongest in sigma due to matrix mult.). We need to 'partial' the
13+ jit function because it is used within a class.
14+
15+ - jax.numpy (jnp) operations are JAX implementations of numpy functions.
16+
17+ - jax.grad used as the gradient function. Returns gradient with respect to first parameter.
18+
19+ - jax.vmap is used to 'vectorize' the jax.grad function. Used to compute gradient of batch elements at once, in parallel.
20+ """

    def __init__(self, learning_rate=0.001, num_epochs=50, size_batch=20):
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.size_batch = size_batch
    def fit(self, data, y):
        self.K = int(jnp.max(y)) + 1                    # Number of classes (labels assumed to be 0..K-1)
        ones = jnp.ones((data.shape[0], 1))
        X = jnp.concatenate((ones, data), axis=1)       # Augment with a column of ones for the intercept
        W = jnp.zeros((X.shape[1], self.K))

        self.coeff = self.mb_gd(W, X, y)

    # Mini-batch gradient descent (jitted functions require arrays whose shapes do not change)
    def mb_gd(self, W, X, y):
        num_epochs = self.num_epochs
        size_batch = self.size_batch
        eta = self.learning_rate
        N = X.shape[0]

        # Define the gradient function using jit, vmap, and JAX's own gradient function, grad.
        # vmap is especially useful for mini-batch GD since we compute all gradients of the batch at once, in parallel.
        # The parameters in_axes and out_axes define the axis of the inputs (W, X, y) and output (per-example
        # gradients) along which to vectorize. grads_b = loss_grad(W, X_batch, y_batch) has shape
        # (batch_size, p+1, K) for p features and K classes. A standalone demo of this pattern is sketched at the
        # bottom of this file.
        loss_grad = jit(vmap(grad(self.loss), in_axes=(None, 0, 0), out_axes=0))

        for e in range(num_epochs):
            shuffle_index = random.permutation(random.PRNGKey(e), N)   # Reshuffle with a fresh key each epoch
            start_time = time.time()
            for m in range(0, N, size_batch):
                i = shuffle_index[m:m + size_batch]

                # 3D jax array of shape (batch_size, p+1, K): one gradient per batch element
                grads_b = loss_grad(W, X[i, :], y[i])

                W -= eta * jnp.mean(grads_b, axis=0)   # Update W with the gradient averaged over the batch

            epoch_time = time.time() - start_time      # Epoch timer
            if e % 10 == 0:
                print("Time to complete epoch", e, ":", epoch_time)
        return W

    def predict(self, data):
        ones = jnp.ones((data.shape[0], 1))
        X = jnp.concatenate((ones, data), axis=1)       # Augment to account for the intercept
        W = self.coeff
        y_pred = jnp.argmax(self.sigma(X, W), axis=1)   # Predicted class is the largest softmax probability
        return y_pred

    def score(self, data, y_true):
        y_pred = self.predict(data)                     # predict() handles the intercept augmentation itself
        acc = jnp.mean(y_pred == y_true)
        return acc

    # jitting 'sigma' is the biggest speed-up compared to the original implementation
    @partial(jit, static_argnums=0)
    def sigma(self, X, W):
        if X.ndim == 1:
            X = jnp.reshape(X, (1, -1))   # Under vmap/grad, X is a single example, so reshape it to (1, p+1)
        s = jnp.exp(jnp.matmul(X, W))
        total = jnp.sum(s, axis=1).reshape(-1, 1)
        return s / total

    @partial(jit, static_argnums=0)
    def loss(self, W, X, y):
        # Cross-entropy loss: for each row, the (y == k) indicator picks out the log-probability of its true class.
        f_value = self.sigma(X, W)
        loss_vector = jnp.zeros(X.shape[0])
        for k in range(self.K):
            loss_vector += jnp.log(f_value + 1e-10)[:, k] * (y == k)
        return -jnp.mean(loss_vector)
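

# A minimal, self-contained sketch (not part of the class) of the per-example gradient pattern
# used in mb_gd above: grad differentiates a scalar loss with respect to its first argument, and
# vmap with in_axes=(None, 0, 0) maps that gradient over the batch axes of X and y while
# broadcasting W. The toy squared-error loss here is only an illustration, not the class's
# cross-entropy loss.
def _per_example_grad_demo():
    def toy_loss(W, x, y):
        # Scalar loss for one example x of shape (p,) and scalar target y.
        return (jnp.dot(x, W) - y) ** 2

    W = jnp.zeros(3)                                     # p = 3 weights
    X = random.normal(random.PRNGKey(0), (5, 3))         # batch of 5 examples
    y = jnp.arange(5.0)
    per_example = vmap(grad(toy_loss), in_axes=(None, 0, 0))(W, X, y)
    print(per_example.shape)                             # (5, 3): one gradient per example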
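

# Usage sketch: fitting JaxReg on small synthetic data. The data, labels, and hyperparameters
# below are made up purely for illustration; the real target dataset (k49) would be loaded and
# flattened to a 2D (n_samples, n_features) float array first.
if __name__ == "__main__":
    key_data, key_labels = random.split(random.PRNGKey(42))
    n_samples, n_features, n_classes = 1000, 20, 3
    data = random.normal(key_data, (n_samples, n_features))
    labels = random.randint(key_labels, (n_samples,), 0, n_classes)

    model = JaxReg(learning_rate=0.01, num_epochs=20, size_batch=50)
    model.fit(data, labels)
    print("Training accuracy:", model.score(data, labels))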