|
30 | 30 | } |
31 | 31 |
|
32 | 32 | estimators['fm-3'] = clone(estimators['fm-2']).set_params(degree=3) |
| 33 | +estimators['fm-2-ada'] = clone(estimators['fm-2']).set_params( |
| 34 | + solver='adagrad', learning_rate=0.01, max_iter=20) |
| 35 | +estimators['fm-3-ada'] = clone(estimators['fm-3']).set_params( |
| 36 | + solver='adagrad', learning_rate=0.01, max_iter=20 |
| 37 | +) |
33 | 38 | estimators['polynet-3'] = (clone(estimators['polynet-2']) |
34 | 39 | .set_params(degree=3, n_components=10)) |
35 | 40 |
|
36 | 41 | if __name__ == '__main__': |
37 | 42 | data_train = fetch_20newsgroups_vectorized(subset="train") |
38 | 43 | data_test = fetch_20newsgroups_vectorized(subset="test") |
39 | | - X_train = sp.csc_matrix(data_train.data) |
40 | | - X_test = sp.csc_matrix(data_test.data) |
| 44 | + X_train_csc = sp.csc_matrix(data_train.data) |
| 45 | + X_test_csc = sp.csc_matrix(data_test.data) |
| 46 | + X_train_csr = sp.csr_matrix(data_train.data) |
| 47 | + X_test_csr = sp.csr_matrix(data_test.data) |
41 | 48 |
|
42 | 49 | y_train = data_train.target == 0 # atheism vs rest |
43 | 50 | y_test = data_test.target == 0 |
44 | 51 |
|
45 | 52 | print("20 newsgroups") |
46 | 53 | print("=============") |
47 | | - print("X_train.shape = {0}".format(X_train.shape)) |
48 | | - print("X_train.format = {0}".format(X_train.format)) |
49 | | - print("X_train.dtype = {0}".format(X_train.dtype)) |
| 54 | + print("X_train.shape = {0}".format(X_train_csr.shape)) |
| 55 | + print("X_train.dtype = {0}".format(X_train_csr.dtype)) |
50 | 56 | print("X_train density = {0}" |
51 | | - "".format(X_train.nnz / np.product(X_train.shape))) |
| 57 | + "".format(X_train_csr.nnz / np.prod(X_train_csr.shape))) |
52 | 58 | print("y_train {0}".format(y_train.shape)) |
53 | | - print("X_test {0}".format(X_test.shape)) |
54 | | - print("X_test.format = {0}".format(X_test.format)) |
55 | | - print("X_test.dtype = {0}".format(X_test.dtype)) |
| 59 | + print("X_test {0}".format(X_test_csr.shape)) |
| 60 | + print("X_test.dtype = {0}".format(X_test_csr.dtype)) |
56 | 61 | print("y_test {0}".format(y_test.shape)) |
57 | 62 | print() |
58 | 63 |
|
|
62 | 67 |
|
63 | 68 | for name, clf in sorted(estimators.items()): |
64 | 69 | print("Training %s ... " % name, end="") |
| 70 | + if 'ada' in name: |
| 71 | + X_train, X_test = X_train_csr, X_test_csr |
| 72 | + else: |
| 73 | + X_train, X_test = X_train_csc, X_test_csc |
65 | 74 | t0 = time() |
66 | 75 | clf.fit(X_train, y_train) |
67 | 76 | train_time[name] = time() - t0 |
|
0 commit comments