Skip to content

Commit 4d72ea8

Browse files
authored
update plot_roc_curve to plot_roc (#86)
1 parent 87df2b4 commit 4d72ea8

File tree

5 files changed

+207
-7
lines changed

5 files changed

+207
-7
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ predicted_probas = nb.predict_proba(X_test)
4141
# The magic happens here
4242
import matplotlib.pyplot as plt
4343
import scikitplot as skplt
44-
skplt.metrics.plot_roc_curve(y_test, predicted_probas)
44+
skplt.metrics.plot_roc(y_test, predicted_probas)
4545
plt.show()
4646
```
4747
![roc_curves](examples/roc_curves.png)

docs/metrics.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ Metrics Module (API Reference)
55
==============================
66

77
.. automodule:: scikitplot.metrics
8-
:members: plot_confusion_matrix, plot_roc_curve, plot_ks_statistic, plot_precision_recall, plot_silhouette, plot_calibration_curve, plot_cumulative_gain, plot_lift_curve
8+
:members: plot_confusion_matrix, plot_roc, plot_ks_statistic, plot_precision_recall, plot_silhouette, plot_calibration_curve, plot_cumulative_gain, plot_lift_curve

examples/plot_roc_curve.py renamed to examples/plot_roc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
nb = GaussianNB()
1414
nb.fit(X, y)
1515
probas = nb.predict_proba(X)
16-
skplt.metrics.plot_roc_curve(y_true=y, y_probas=probas)
16+
skplt.metrics.plot_roc(y_true=y, y_probas=probas)
1717
plt.show()

scikitplot/metrics.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
175175
return ax
176176

177177

178+
@deprecated('This will be removed in v0.5.0. Please use '
179+
'scikitplot.metrics.plot_roc instead.')
178180
def plot_roc_curve(y_true, y_probas, title='ROC Curves',
179181
curves=('micro', 'macro', 'each_class'),
180182
ax=None, figsize=None, cmap='nipy_spectral',
@@ -322,6 +324,138 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves',
322324
return ax
323325

324326

327+
def plot_roc(y_true, y_probas, title='ROC Curves',
328+
plot_micro=True, plot_macro=True, classes_to_plot=None,
329+
ax=None, figsize=None, cmap='nipy_spectral',
330+
title_fontsize="large", text_fontsize="medium"):
331+
"""Generates the ROC curves from labels and predicted scores/probabilities
332+
333+
Args:
334+
y_true (array-like, shape (n_samples)):
335+
Ground truth (correct) target values.
336+
337+
y_probas (array-like, shape (n_samples, n_classes)):
338+
Prediction probabilities for each class returned by a classifier.
339+
340+
title (string, optional): Title of the generated plot. Defaults to
341+
"ROC Curves".
342+
343+
plot_micro (boolean, optional): Plot the micro average ROC curve.
344+
Defaults to ``True``.
345+
346+
plot_macro (boolean, optional): Plot the macro average ROC curve.
347+
Defaults to ``True``.
348+
349+
classes_to_plot (list-like, optional): Classes for which the ROC
350+
curve should be plotted, e.g. [0, 'cold']. If a given class does not exist,
351+
it will be ignored. If ``None``, all classes will be plotted. Defaults to
352+
``None``
353+
354+
ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
355+
plot the curve. If None, the plot is drawn on a new set of axes.
356+
357+
figsize (2-tuple, optional): Tuple denoting figure size of the plot
358+
e.g. (6, 6). Defaults to ``None``.
359+
360+
cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
361+
Colormap used for plotting the ROC curves. View Matplotlib Colormap
362+
documentation for available options.
363+
https://matplotlib.org/users/colormaps.html
364+
365+
title_fontsize (string or int, optional): Matplotlib-style fontsizes.
366+
Use e.g. "small", "medium", "large" or integer-values. Defaults to
367+
"large".
368+
369+
text_fontsize (string or int, optional): Matplotlib-style fontsizes.
370+
Use e.g. "small", "medium", "large" or integer-values. Defaults to
371+
"medium".
372+
373+
Returns:
374+
ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
375+
drawn.
376+
377+
Example:
378+
>>> import scikitplot as skplt
379+
>>> nb = GaussianNB()
380+
>>> nb = nb.fit(X_train, y_train)
381+
>>> y_probas = nb.predict_proba(X_test)
382+
>>> skplt.metrics.plot_roc(y_test, y_probas)
383+
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
384+
>>> plt.show()
385+
386+
.. image:: _static/examples/plot_roc_curve.png
387+
:align: center
388+
:alt: ROC Curves
389+
"""
390+
y_true = np.array(y_true)
391+
y_probas = np.array(y_probas)
392+
393+
classes = np.unique(y_true)
394+
probas = y_probas
395+
396+
if classes_to_plot is None:
397+
classes_to_plot = classes
398+
399+
if ax is None:
400+
fig, ax = plt.subplots(1, 1, figsize=figsize)
401+
402+
ax.set_title(title, fontsize=title_fontsize)
403+
404+
fpr_dict = dict()
405+
tpr_dict = dict()
406+
407+
indices_to_plot = np.in1d(classes, classes_to_plot)
408+
for i, to_plot in enumerate(indices_to_plot):
409+
fpr_dict[i], tpr_dict[i], _ = roc_curve(y_true, probas[:, i],
410+
pos_label=classes[i])
411+
if to_plot:
412+
roc_auc = auc(fpr_dict[i], tpr_dict[i])
413+
color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
414+
ax.plot(fpr_dict[i], tpr_dict[i], lw=2, color=color,
415+
label='ROC curve of class {0} (area = {1:0.2f})'
416+
''.format(classes[i], roc_auc))
417+
418+
if plot_micro:
419+
binarized_y_true = label_binarize(y_true, classes=classes)
420+
if len(classes) == 2:
421+
binarized_y_true = np.hstack(
422+
(1 - binarized_y_true, binarized_y_true))
423+
fpr, tpr, _ = roc_curve(binarized_y_true.ravel(), probas.ravel())
424+
roc_auc = auc(fpr, tpr)
425+
ax.plot(fpr, tpr,
426+
label='micro-average ROC curve '
427+
'(area = {0:0.2f})'.format(roc_auc),
428+
color='deeppink', linestyle=':', linewidth=4)
429+
430+
if plot_macro:
431+
# Compute macro-average ROC curve and ROC area
432+
# First aggregate all false positive rates
433+
all_fpr = np.unique(np.concatenate([fpr_dict[x] for x in range(len(classes))]))
434+
435+
# Then interpolate all ROC curves at these points
436+
mean_tpr = np.zeros_like(all_fpr)
437+
for i in range(len(classes)):
438+
mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
439+
440+
# Finally average it and compute AUC
441+
mean_tpr /= len(classes)
442+
roc_auc = auc(all_fpr, mean_tpr)
443+
444+
ax.plot(all_fpr, mean_tpr,
445+
label='macro-average ROC curve '
446+
'(area = {0:0.2f})'.format(roc_auc),
447+
color='navy', linestyle=':', linewidth=4)
448+
449+
ax.plot([0, 1], [0, 1], 'k--', lw=2)
450+
ax.set_xlim([0.0, 1.0])
451+
ax.set_ylim([0.0, 1.05])
452+
ax.set_xlabel('False Positive Rate', fontsize=text_fontsize)
453+
ax.set_ylabel('True Positive Rate', fontsize=text_fontsize)
454+
ax.tick_params(labelsize=text_fontsize)
455+
ax.legend(loc='lower right', fontsize=text_fontsize)
456+
return ax
457+
458+
325459
def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
326460
ax=None, figsize=None, title_fontsize="large",
327461
text_fontsize="medium"):
@@ -554,7 +688,7 @@ def plot_precision_recall(y_true, y_probas,
554688
"Precision-Recall curve".
555689
556690
plot_micro (boolean, optional): Plot the micro average precision-recall curve.
557-
Defaults to `True`.
691+
Defaults to ``True``.
558692
559693
classes_to_plot (list-like, optional): Classes for which the precision-recall
560694
curve should be plotted. e.g. [0, 'cold']. If given class does not exist,

scikitplot/tests/test_metrics.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from scikitplot.metrics import plot_confusion_matrix
1515
from scikitplot.metrics import plot_roc_curve
16+
from scikitplot.metrics import plot_roc
1617
from scikitplot.metrics import plot_ks_statistic
1718
from scikitplot.metrics import plot_precision_recall_curve
1819
from scikitplot.metrics import plot_precision_recall
@@ -156,6 +157,72 @@ def test_array_like(self):
156157
plot_roc_curve(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]])
157158

158159

160+
class TestPlotROC(unittest.TestCase):
161+
def setUp(self):
162+
np.random.seed(0)
163+
self.X, self.y = load_data(return_X_y=True)
164+
p = np.random.permutation(len(self.X))
165+
self.X, self.y = self.X[p], self.y[p]
166+
167+
def tearDown(self):
168+
plt.close("all")
169+
170+
def test_string_classes(self):
171+
np.random.seed(0)
172+
clf = LogisticRegression()
173+
clf.fit(self.X, convert_labels_into_string(self.y))
174+
probas = clf.predict_proba(self.X)
175+
plot_roc(convert_labels_into_string(self.y), probas)
176+
177+
def test_ax(self):
178+
np.random.seed(0)
179+
clf = LogisticRegression()
180+
clf.fit(self.X, self.y)
181+
probas = clf.predict_proba(self.X)
182+
fig, ax = plt.subplots(1, 1)
183+
out_ax = plot_roc(self.y, probas)
184+
assert ax is not out_ax
185+
out_ax = plot_roc(self.y, probas, ax=ax)
186+
assert ax is out_ax
187+
188+
def test_cmap(self):
189+
np.random.seed(0)
190+
clf = LogisticRegression()
191+
clf.fit(self.X, self.y)
192+
probas = clf.predict_proba(self.X)
193+
plot_roc(self.y, probas, cmap='nipy_spectral')
194+
plot_roc(self.y, probas, cmap=plt.cm.nipy_spectral)
195+
196+
def test_plot_micro(self):
197+
np.random.seed(0)
198+
clf = LogisticRegression()
199+
clf.fit(self.X, self.y)
200+
probas = clf.predict_proba(self.X)
201+
plot_roc(self.y, probas, plot_micro=False)
202+
plot_roc(self.y, probas, plot_micro=True)
203+
204+
def test_plot_macro(self):
205+
np.random.seed(0)
206+
clf = LogisticRegression()
207+
clf.fit(self.X, self.y)
208+
probas = clf.predict_proba(self.X)
209+
plot_roc(self.y, probas, plot_macro=False)
210+
plot_roc(self.y, probas, plot_macro=True)
211+
212+
def test_classes_to_plot(self):
213+
np.random.seed(0)
214+
clf = LogisticRegression()
215+
clf.fit(self.X, self.y)
216+
probas = clf.predict_proba(self.X)
217+
plot_roc(self.y, probas, classes_to_plot=[0, 1])
218+
plot_roc(self.y, probas, classes_to_plot=np.array([0, 1]))
219+
220+
def test_array_like(self):
221+
plot_roc([0, 'a'], [[0.8, 0.2], [0.2, 0.8]])
222+
plot_roc([0, 1], [[0.8, 0.2], [0.2, 0.8]])
223+
plot_roc(['b', 'a'], [[0.8, 0.2], [0.2, 0.8]])
224+
225+
159226
class TestPlotKSStatistic(unittest.TestCase):
160227
def setUp(self):
161228
np.random.seed(0)
@@ -292,9 +359,8 @@ def test_plot_micro(self):
292359
clf = LogisticRegression()
293360
clf.fit(self.X, self.y)
294361
probas = clf.predict_proba(self.X)
295-
ax_micro = plot_precision_recall(self.y, probas, plot_micro=True)
296-
ax_class = plot_precision_recall(self.y, probas, plot_micro=False)
297-
self.assertNotEqual(ax_micro, ax_class)
362+
plot_precision_recall(self.y, probas, plot_micro=True)
363+
plot_precision_recall(self.y, probas, plot_micro=False)
298364

299365
def test_cmap(self):
300366
np.random.seed(0)

0 commit comments

Comments
 (0)