
Commit 66b64f8 (initial commit, 0 parents)

Commit message: new

File tree

8 files changed (+145, -0 lines)


.idea/.gitignore: 3 additions & 0 deletions
.idea/inspectionProfiles/Project_Default.xml: 13 additions & 0 deletions
.idea/inspectionProfiles/profiles_settings.xml: 6 additions & 0 deletions
.idea/jaxlogreg.iml: 10 additions & 0 deletions
.idea/misc.xml: 4 additions & 0 deletions
.idea/modules.xml: 8 additions & 0 deletions
.idea/vcs.xml: 6 additions & 0 deletions

(These generated IDE files are not rendered in the diff by default.)

jaxlogreg.py: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
from jax import jit, grad, vmap, device_put, random
import jax.numpy as jnp
from functools import partial
import time


class JaxReg:
    """
    Logistic regression classifier with GPU acceleration through Google's JAX. The point of this class is fitting
    speed: I want it to fit a model on very large datasets (k49 in particular) as quickly as possible!

    - jit compilation is used in the sigma and loss methods (the biggest gain is in sigma, due to the matrix
      multiplication). jit has to be applied through functools.partial because these are methods on a class.

    - jax.numpy (jnp) operations are JAX implementations of NumPy functions.

    - jax.grad is used as the gradient function. It returns the gradient with respect to the first parameter.

    - jax.vmap 'vectorizes' the jax.grad function, so the gradients of all batch elements are computed at once,
      in parallel.
    """

    def __init__(self, learning_rate=.001, num_epochs=50, size_batch=20):
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.size_batch = size_batch

    def fit(self, data, y):
        self.K = int(jnp.max(y)) + 1                # Number of classes; labels are assumed to be 0..K-1
        ones = jnp.ones((data.shape[0], 1))
        X = jnp.concatenate((ones, data), axis=1)   # Augment with a column of ones for the intercept
        W = jnp.zeros((X.shape[1], self.K))         # One weight column per class

        self.coeff = self.mb_gd(W, X, y)

    # Mini-batch gradient descent, rewritten because jitted functions require arrays whose shapes do not change
    def mb_gd(self, W, X, y):
        num_epochs = self.num_epochs
        size_batch = self.size_batch
        eta = self.learning_rate
        N = X.shape[0]

        # Define the gradient function using jit, vmap, and JAX's own gradient function, grad.
        # vmap is especially useful for mini-batch GD since we compute all gradients of the batch at once, in parallel.
        # The in_axes and out_axes parameters define the axes of the inputs (W, X, y) and of the output (per-example
        # gradients) over which to vectorize. grads_b = loss_grad(W, X_batch, y_batch) has shape (batch_size, p+1, K)
        # for p features and K classes.
        loss_grad = jit(vmap(grad(self.loss), in_axes=(None, 0, 0), out_axes=0))

        for e in range(num_epochs):
            shuffle_index = random.permutation(random.PRNGKey(e), N)
            start_time = time.time()
            for m in range(0, N, size_batch):
                i = shuffle_index[m:m + size_batch]

                # 3D JAX array of shape (batch_size, p+1, K): one gradient per batch element
                grads_b = loss_grad(W, X[i, :], y[i])

                W -= eta * jnp.mean(grads_b, axis=0)  # Update W with the average gradient over the batch

            epoch_time = time.time() - start_time  # Epoch timer
            if e % 10 == 0:
                print("Time to complete epoch", e, ":", epoch_time)
        return W

    def predict(self, data):
        ones = jnp.ones((data.shape[0], 1))
        X = jnp.concatenate((ones, data), axis=1)   # Augment to account for the intercept
        W = self.coeff
        # Predicted class is the one with the largest probability in the softmax output
        y_pred = jnp.argmax(self.sigma(X, W), axis=1)
        return y_pred

    def score(self, data, y_true):
        y_pred = self.predict(data)                 # predict() already augments the data with the intercept column
        acc = jnp.mean(y_pred == y_true)
        return acc

    # jitting 'sigma' is the biggest speed-up compared to the original implementation
    @partial(jit, static_argnums=0)
    def sigma(self, X, W):
        if X.ndim == 1:
            X = jnp.reshape(X, (-1, X.shape[0]))    # Under vmap/grad, X arrives as a single row: reshape to (1, p+1)
        s = jnp.exp(jnp.matmul(X, W))
        total = jnp.sum(s, axis=1).reshape(-1, 1)
        return s / total                            # Row-wise softmax

    @partial(jit, static_argnums=0)
    def loss(self, W, X, y):
        # Cross-entropy loss: for each row, pick out the log-probability of the true class
        f_value = self.sigma(X, W)
        loss_vector = jnp.zeros(f_value.shape[0])   # f_value has shape (num_rows, K) even when X is a single example
        for k in range(self.K):
            loss_vector += jnp.log(f_value + 1e-10)[:, k] * (y == k)
        return -jnp.mean(loss_vector)
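
For reference, a hypothetical usage sketch of the class above (not part of the commit), assuming jaxlogreg.py is importable and that labels are integers 0..K-1; the data is random and purely illustrative, not the k49 dataset mentioned in the docstring.

import jax.numpy as jnp
from jax import random
from jaxlogreg import JaxReg

X_train = random.normal(random.PRNGKey(42), (200, 10))     # 200 examples, 10 features
y_train = random.randint(random.PRNGKey(7), (200,), 0, 3)  # 3 classes

clf = JaxReg(learning_rate=.001, num_epochs=50, size_batch=20)
clf.fit(X_train, y_train)          # runs mini-batch GD, printing the epoch timer every 10 epochs

y_pred = clf.predict(X_train)      # class with the largest softmax probability per row
print("training accuracy:", clf.score(X_train, y_train))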

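As a companion to the comments in the docstring and in mb_gd, here is a minimal standalone sketch (again, not from the commit) of the jit(vmap(grad(...))) per-example-gradient pattern: grad differentiates with respect to the first argument, vmap maps over the leading axis of X and y while broadcasting W (in_axes=(None, 0, 0)), and jit compiles the result. The toy loss nll and the shapes below are made up purely for illustration.

import jax.numpy as jnp
from jax import jit, grad, vmap, random

def nll(W, x, y):
    # Negative log-likelihood of a single example under a softmax model
    logits = x @ W
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits)))
    return -log_probs[y]

# W is not mapped (None); x and y are mapped over their leading axis (0)
per_example_grads = jit(vmap(grad(nll), in_axes=(None, 0, 0), out_axes=0))

W = jnp.zeros((5, 3))                                # p+1 = 5 inputs, K = 3 classes
X = random.normal(random.PRNGKey(0), (8, 5))         # batch of 8 examples
y = random.randint(random.PRNGKey(1), (8,), 0, 3)    # integer labels in {0, 1, 2}

grads_b = per_example_grads(W, X, y)
print(grads_b.shape)   # (8, 5, 3): one (p+1, K) gradient per batch element, just as in mb_gd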
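One more small sketch (not from the commit): the loop over classes in loss builds the usual one-hot cross-entropy by masking with (y == k); the same quantity can also be written by indexing the true-class log-probabilities directly. The arrays here are hypothetical.

import jax.numpy as jnp

log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.8, 0.1]]) + 1e-10)   # shape (2 rows, 3 classes)
y = jnp.array([0, 1])                                       # true classes

# Loop-and-mask form, as in JaxReg.loss
loss_loop = -jnp.mean(sum(log_probs[:, k] * (y == k) for k in range(3)))

# Direct indexing form: pick each row's true-class log-probability
loss_index = -jnp.mean(log_probs[jnp.arange(2), y])

print(loss_loop, loss_index)   # both are about 0.290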