Merged
57 changes: 29 additions & 28 deletions packages/scikit-learn/index.rst
@@ -570,7 +570,8 @@ One good method to keep in mind is Gaussian Naive Bayes
 >>> from sklearn.model_selection import train_test_split
 
 >>> # split the data into training and validation sets
->>> X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)
+>>> X_train, X_test, y_train, y_test = train_test_split(
+...     digits.data, digits.target, random_state=42)
 
 >>> # train the model
 >>> clf = GaussianNB()
@@ -581,9 +582,9 @@ One good method to keep in mind is Gaussian Naive Bayes
 >>> predicted = clf.predict(X_test)
 >>> expected = y_test
 >>> print(predicted)
-[5 1 7 2 8 9 4 3 9 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 9 8 1 8...]
+[6 9 3 7 2 2 5 8 5 2 1 1 7 0 4 8 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
 >>> print(expected)
-[5 8 7 2 8 9 4 3 7 3 6 2 3 2 6 7 4 3 5 7 5 7 0 1 2 5 3 3 1 8...]
+[6 9 3 7 2 1 5 2 5 2 1 9 4 0 4 2 3 7 8 8 4 3 9 7 5 6 3 5 6 3...]
 
 As above, we plot the digits with the predicted labels to get an idea of
 how well the classification is working.
@@ -607,11 +608,11 @@ the number of matches::
 
 >>> matches = (predicted == expected)
 >>> print(matches.sum())
-371
+385
 >>> print(len(matches))
 450
 >>> matches.sum() / float(len(matches))
-0.82444...
+0.8555...
 
 We see that more than 80% of the 450 predictions match the input. But
 there are other more sophisticated metrics that can be used to judge the
@@ -625,20 +626,20 @@ combines several measures and prints a table with the results::
 >>> print(metrics.classification_report(expected, predicted))
               precision    recall  f1-score   support
 <BLANKLINE>
-           0       1.00      0.98      0.99        45
-           1       0.91      0.66      0.76        44
-           2       0.91      0.56      0.69        36
-           3       0.89      0.67      0.77        49
-           4       0.95      0.83      0.88        46
-           5       0.93      0.93      0.93        45
-           6       0.92      0.98      0.95        47
-           7       0.75      0.96      0.84        50
-           8       0.49      0.97      0.66        39
-           9       0.85      0.67      0.75        49
+           0       1.00      0.95      0.98        43
+           1       0.85      0.78      0.82        37
+           2       0.85      0.61      0.71        38
+           3       0.97      0.83      0.89        46
+           4       0.98      0.84      0.90        55
+           5       0.90      0.95      0.93        59
+           6       0.90      0.96      0.92        45
+           7       0.71      0.98      0.82        41
+           8       0.60      0.89      0.72        38
+           9       0.90      0.73      0.80        48
 <BLANKLINE>
-    accuracy                           0.82       450
-   macro avg       0.86      0.82      0.82       450
-weighted avg       0.86      0.82      0.83       450
+    accuracy                           0.86       450
+   macro avg       0.87      0.85      0.85       450
+weighted avg       0.88      0.86      0.86       450
 <BLANKLINE>
 
 
@@ -647,16 +648,16 @@ is a *confusion matrix*: it helps us visualize which labels are being
 interchanged in the classification errors::
 
 >>> print(metrics.confusion_matrix(expected, predicted))
-[[44  0  0  0  0  0  0  0  0  1]
- [ 0 29  0  0  0  0  1  6  6  2]
- [ 0  1 20  1  0  0  0  0 14  0]
- [ 0  0  0 33  0  2  0  1 11  2]
- [ 0  0  0  0 38  1  2  4  1  0]
- [ 0  0  0  0  0 42  1  0  2  0]
- [ 0  0  0  0  0  0 46  0  1  0]
- [ 0  0  0  0  1  0  0 48  0  1]
- [ 0  1  0  0  0  0  0  0 38  0]
- [ 0  1  2  3  1  0  0  5  4 33]]
+[[41  0  0  0  0  1  0  1  0  0]
+ [ 0 29  2  0  0  0  0  0  4  2]
+ [ 0  2 23  0  0  0  1  0 12  0]
+ [ 0  0  1 38  0  1  0  0  5  1]
+ [ 0  0  0  0 46  0  2  7  0  0]
+ [ 0  0  0  0  0 56  1  1  0  1]
+ [ 0  0  0  0  1  1 43  0  0  0]
+ [ 0  0  0  0  0  1  0 40  0  0]
+ [ 0  2  0  0  0  0  0  2 34  0]
+ [ 0  1  1  1  0  2  1  5  2 35]]
 
 We see here that in particular, the numbers 1, 2, 3, and 9 are often
 being labeled 8.
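
The doctest updates above can be checked with a short standalone script (a sketch: it re-runs the tutorial's Gaussian Naive Bayes example on the digits data with the pinned `random_state=42`; the exact printed scores depend on the installed scikit-learn version, which is why the expected outputs needed refreshing):

```python
# Re-run the tutorial snippet: Gaussian Naive Bayes on the digits dataset.
from sklearn import metrics
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

digits = load_digits()

# split the data into training and validation sets (fixed seed, as in the diff)
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, random_state=42)

# train the model and predict on the held-out set
clf = GaussianNB()
clf.fit(X_train, y_train)
predicted = clf.predict(X_test)
expected = y_test

# fraction of exact matches (the default split holds out 450 of 1797 digits)
matches = (predicted == expected)
print(matches.sum(), "of", len(matches), "predictions correct")

# per-class precision/recall table and the confusion matrix
print(metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))
```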
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ matplotlib==3.9.0
 pandas==2.2.2
 patsy==0.5.6
 pyarrow==16.1.0
-scikit-learn==1.4.2
+scikit-learn==1.5.0
 scikit-image==0.23.2
 sympy==1.12.1
 statsmodels==0.14.2