From f6cfeb9d5522664efdd8e9f38d07d82cb2b2c235 Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Wed, 5 Jun 2024 07:53:06 -0700
Subject: [PATCH] Set random state when splitting data

---
 packages/scikit-learn/index.rst | 57 +++++++++++++++++----------
 requirements.txt                |  2 +-
 2 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/packages/scikit-learn/index.rst b/packages/scikit-learn/index.rst
index e9a2ac0b3..447cee5a8 100644
--- a/packages/scikit-learn/index.rst
+++ b/packages/scikit-learn/index.rst
@@ -570,7 +570,8 @@ One good method to keep in mind is Gaussian Naive Bayes
 
     >>> from sklearn.model_selection import train_test_split
     >>> # split the data into training and validation sets
-    >>> X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
+    >>> X_train, X_test, y_train, y_test = train_test_split(
+    ...     digits.data, digits.target, random_state=42)
 
     >>> # train the model
     >>> clf = GaussianNB()
@@ -581,9 +582,9 @@ One good method to keep in mind is Gaussian Naive Bayes
     >>> predicted = clf.predict(X_test)
     >>> expected = y_test
     >>> print(predicted)
-    [5 1 7 2 8 9 4 3 9 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 9 8 1 8...]
+    [6 9 3 7 2 2 5 8 5 2 1 1 7 0 4 8 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
     >>> print(expected)
-    [5 8 7 2 8 9 4 3 7 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 3 3 1 8...]
+    [6 9 3 7 2 1 5 2 5 2 1 9 4 0 4 2 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
 
 As above, we plot the digits with the predicted labels to get an idea of
 how well the classification is working.
@@ -607,11 +608,11 @@ the number of matches::
 
     >>> matches = (predicted == expected)
     >>> print(matches.sum())
-    371
+    385
     >>> print(len(matches))
     450
     >>> matches.sum() / float(len(matches))
-    0.82444...
+    0.8555...
 
 We see that more than 80% of the 450 predictions match the input. But
 there are other more sophisticated metrics that can be used to judge the
@@ -625,20 +626,20 @@ combines several measures and prints a table with the results::
     >>> print(metrics.classification_report(expected, predicted))
                   precision    recall  f1-score   support
 
-               0       1.00      0.98      0.99        45
-               1       0.91      0.66      0.76        44
-               2       0.91      0.56      0.69        36
-               3       0.89      0.67      0.77        49
-               4       0.95      0.83      0.88        46
-               5       0.93      0.93      0.93        45
-               6       0.92      0.98      0.95        47
-               7       0.75      0.96      0.84        50
-               8       0.49      0.97      0.66        39
-               9       0.85      0.67      0.75        49
+               0       1.00      0.95      0.98        43
+               1       0.85      0.78      0.82        37
+               2       0.85      0.61      0.71        38
+               3       0.97      0.83      0.89        46
+               4       0.98      0.84      0.90        55
+               5       0.90      0.95      0.93        59
+               6       0.90      0.96      0.92        45
+               7       0.71      0.98      0.82        41
+               8       0.60      0.89      0.72        38
+               9       0.90      0.73      0.80        48
 
-        accuracy                           0.82       450
-       macro avg       0.86      0.82      0.82       450
-    weighted avg       0.86      0.82      0.83       450
+        accuracy                           0.86       450
+       macro avg       0.87      0.85      0.85       450
+    weighted avg       0.88      0.86      0.86       450
 
 
 
@@ -647,16 +648,16 @@ is a *confusion matrix*: it helps us visualize which labels are being
 interchanged in the classification errors::
 
     >>> print(metrics.confusion_matrix(expected, predicted))
-    [[44  0  0  0  0  0  0  0  0  1]
-     [ 0 29  0  0  0  0  1  6  6  2]
-     [ 0  1 20  1  0  0  0  0 14  0]
-     [ 0  0  0 33  0  2  0  1 11  2]
-     [ 0  0  0  0 38  1  2  4  1  0]
-     [ 0  0  0  0  0 42  1  0  2  0]
-     [ 0  0  0  0  0  0 46  0  1  0]
-     [ 0  0  0  0  1  0  0 48  0  1]
-     [ 0  1  0  0  0  0  0  0 38  0]
-     [ 0  1  2  3  1  0  0  5  4 33]]
+    [[41  0  0  0  0  1  0  1  0  0]
+     [ 0 29  2  0  0  0  0  0  4  2]
+     [ 0  2 23  0  0  0  1  0 12  0]
+     [ 0  0  1 38  0  1  0  0  5  1]
+     [ 0  0  0  0 46  0  2  7  0  0]
+     [ 0  0  0  0  0 56  1  1  0  1]
+     [ 0  0  0  0  1  1 43  0  0  0]
+     [ 0  0  0  0  0  1  0 40  0  0]
+     [ 0  2  0  0  0  0  0  2 34  0]
+     [ 0  1  1  1  0  2  1  5  2 35]]
 
 We see here that in particular, the numbers 1, 2, 3, and 9 are often
 being labeled 8.
diff --git a/requirements.txt b/requirements.txt
index 45cef3520..7c053b0e7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ matplotlib==3.9.0
 pandas==2.2.2
 patsy==0.5.6
 pyarrow==16.1.0
-scikit-learn==1.4.2
+scikit-learn==1.5.0
 scikit-image==0.23.2
 sympy==1.12.1
 statsmodels==0.14.2
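
For reference, a minimal standalone sketch of the patched workflow. It assumes
the lecture's ``digits`` object comes from ``load_digits`` (loaded in an
earlier section of the tutorial) and scikit-learn >= 1.5.0 as pinned above;
the exact predictions and scores depend on the library version, which is why
the patch pins the split with ``random_state``::

    # Sketch of the doctest flow above, runnable on its own.
    from sklearn import metrics
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.naive_bayes import GaussianNB

    digits = load_digits()

    # Fixing random_state makes the train/test partition deterministic,
    # so the doctest output stays stable from run to run.
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=42)

    # train the Gaussian Naive Bayes model on the training split
    clf = GaussianNB()
    clf.fit(X_train, y_train)

    # evaluate on the held-out split
    predicted = clf.predict(X_test)
    expected = y_test

    print((predicted == expected).sum(), "/", len(expected), "correct")
    print(metrics.classification_report(expected, predicted))
    print(metrics.confusion_matrix(expected, predicted))

Note that ``random_state`` only fixes the shuffling done by
``train_test_split``; ``GaussianNB`` itself has no randomness, so the split is
the only source of run-to-run variation in this example.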