|
32 | 32 |
|
33 | 33 | def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, |
34 | 34 | pred_labels=None, title=None, normalize=False, |
35 | | - hide_zeros=False, x_tick_rotation=0, ax=None, |
| 35 | + hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None, |
36 | 36 | figsize=None, cmap='Blues', title_fontsize="large", |
37 | 37 | text_fontsize="medium"): |
38 | 38 | """Generates confusion matrix plot from predictions and true labels |
@@ -65,6 +65,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, |
65 | 65 | hide_zeros (bool, optional): If True, does not plot cells containing a |
66 | 66 | value of zero. Defaults to False. |
67 | 67 |
|
| 68 | + hide_counts (bool, optional): If True, doe not overlay counts. |
| 69 | + Defaults to False. |
| 70 | +
|
68 | 71 | x_tick_rotation (int, optional): Rotates x-axis tick labels by the |
69 | 72 | specified angle. This is useful in cases where there are numerous |
70 | 73 | categories and the labels overlap each other. |
@@ -160,13 +163,15 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, |
160 | 163 | ax.set_yticklabels(true_classes, fontsize=text_fontsize) |
161 | 164 |
|
162 | 165 | thresh = cm.max() / 2. |
163 | | - for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): |
164 | | - if not (hide_zeros and cm[i, j] == 0): |
165 | | - ax.text(j, i, cm[i, j], |
166 | | - horizontalalignment="center", |
167 | | - verticalalignment="center", |
168 | | - fontsize=text_fontsize, |
169 | | - color="white" if cm[i, j] > thresh else "black") |
| 166 | + |
| 167 | + if not hide_counts: |
| 168 | + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): |
| 169 | + if not (hide_zeros and cm[i, j] == 0): |
| 170 | + ax.text(j, i, cm[i, j], |
| 171 | + horizontalalignment="center", |
| 172 | + verticalalignment="center", |
| 173 | + fontsize=text_fontsize, |
| 174 | + color="white" if cm[i, j] > thresh else "black") |
170 | 175 |
|
171 | 176 | ax.set_ylabel('True label', fontsize=text_fontsize) |
172 | 177 | ax.set_xlabel('Predicted label', fontsize=text_fontsize) |
|
0 commit comments