Skip to content

Commit 19af1c9

Browse files
echan5reiinakano
authored andcommitted
add hide_counts parameter to plot_confusion_matrix() (#90)
* add hide_counts parameter to plot_confusion_matrix() * add test for hide_counts argument in plot_confusion_matrix()
1 parent 1ee71a0 commit 19af1c9

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

scikitplot/metrics.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
3434
pred_labels=None, title=None, normalize=False,
35-
hide_zeros=False, x_tick_rotation=0, ax=None,
35+
hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None,
3636
figsize=None, cmap='Blues', title_fontsize="large",
3737
text_fontsize="medium"):
3838
"""Generates confusion matrix plot from predictions and true labels
@@ -65,6 +65,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
6565
hide_zeros (bool, optional): If True, does not plot cells containing a
6666
value of zero. Defaults to False.
6767
68+
hide_counts (bool, optional): If True, doe not overlay counts.
69+
Defaults to False.
70+
6871
x_tick_rotation (int, optional): Rotates x-axis tick labels by the
6972
specified angle. This is useful in cases where there are numerous
7073
categories and the labels overlap each other.
@@ -160,13 +163,15 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
160163
ax.set_yticklabels(true_classes, fontsize=text_fontsize)
161164

162165
thresh = cm.max() / 2.
163-
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
164-
if not (hide_zeros and cm[i, j] == 0):
165-
ax.text(j, i, cm[i, j],
166-
horizontalalignment="center",
167-
verticalalignment="center",
168-
fontsize=text_fontsize,
169-
color="white" if cm[i, j] > thresh else "black")
166+
167+
if not hide_counts:
168+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
169+
if not (hide_zeros and cm[i, j] == 0):
170+
ax.text(j, i, cm[i, j],
171+
horizontalalignment="center",
172+
verticalalignment="center",
173+
fontsize=text_fontsize,
174+
color="white" if cm[i, j] > thresh else "black")
170175

171176
ax.set_ylabel('True label', fontsize=text_fontsize)
172177
ax.set_xlabel('Predicted label', fontsize=text_fontsize)

scikitplot/tests/test_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ def test_labels(self):
5959
preds = clf.predict(self.X)
6060
plot_confusion_matrix(self.y, preds, labels=[0, 1, 2])
6161

62+
def test_hide_counts(self):
63+
np.random.seed(0)
64+
clf = LogisticRegression()
65+
clf.fit(self.X, self.y)
66+
preds = clf.predict(self.X)
67+
plot_confusion_matrix(self.y, preds, hide_counts=True)
68+
6269
def test_true_pred_labels(self):
6370
np.random.seed(0)
6471
clf = LogisticRegression()

0 commit comments

Comments
 (0)