diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..251cfb0 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -331,7 +331,7 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves', def plot_roc(y_true, y_probas, title='ROC Curves', plot_micro=True, plot_macro=True, classes_to_plot=None, - ax=None, figsize=None, cmap='nipy_spectral', + show_labels=True, ax=None, figsize=None, cmap='nipy_spectral', title_fontsize="large", text_fontsize="medium"): """Generates the ROC curves from labels and predicted scores/probabilities @@ -351,6 +351,9 @@ def plot_roc(y_true, y_probas, title='ROC Curves', plot_macro (boolean, optional): Plot the macro average ROC curve. Defaults to ``True``. + show_labels (boolean, optional): Shows the labels in the plot. + Defaults to ``True``. + classes_to_plot (list-like, optional): Classes for which the ROC curve should be plotted. e.g. [0, 'cold']. If given class does not exist, it will be ignored. If ``None``, all classes will be plotted. Defaults to @@ -457,7 +460,8 @@ def plot_roc(y_true, y_probas, title='ROC Curves', ax.set_xlabel('False Positive Rate', fontsize=text_fontsize) ax.set_ylabel('True Positive Rate', fontsize=text_fontsize) ax.tick_params(labelsize=text_fontsize) - ax.legend(loc='lower right', fontsize=text_fontsize) + if show_labels is True: + ax.legend(loc='lower right', fontsize=text_fontsize) return ax