Skip to content

Commit 92fa955

Browse files
committed
new
1 parent 66b64f8 commit 92fa955

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

jaxlogreg.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from jax import jit,grad,vmap,device_put,random
22
import jax.numpy as jnp
33
from functools import partial
4-
import time
54

65

76
class JaxReg:
@@ -48,18 +47,13 @@ def mb_gd(self, W, X, y):
4847

4948
for e in range(num_epochs):
5049
shuffle_index = random.permutation(random.PRNGKey(e), N)
51-
start_time = time.time()
5250
for m in range(0, N, size_batch):
5351
i = shuffle_index[m:m + size_batch]
5452

5553
grads_b = loss_grad(W, X[i, :],
5654
y[i]) # 3D jax array of size (batch_size, p+1, k): gradients for each batch element
5755

5856
W -= eta * jnp.mean(grads_b, axis=0) # Update W with average over each batch
59-
60-
epoch_time = time.time() - start_time # Epoch timer
61-
if e % 10 == 0:
62-
print("Time to complete epoch", e, ":", epoch_time)
6357
return W
6458

6559
def predict(self, data):

0 commit comments

Comments
 (0)