@@ -909,7 +909,8 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
909909def plot_calibration_curve (y_true , probas_list , clf_names = None , n_bins = 10 ,
910910 title = 'Calibration plots (Reliability Curves)' ,
911911 ax = None , figsize = None , cmap = 'nipy_spectral' ,
912- title_fontsize = "large" , text_fontsize = "medium" ):
912+ title_fontsize = "large" , text_fontsize = "medium" ,
913+ pos_label = None ):
913914 """Plots calibration curves for a set of classifier probability estimates.
914915
915916 Plotting the calibration curves of a classifier is useful for determining
@@ -937,7 +938,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
937938 data.
938939
939940 title (string, optional): Title of the generated plot. Defaults to
940- "Calibration plots (Reliabilirt Curves)"
941+ "Calibration plots (Reliability Curves)".
941942
942943 ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
943944 plot the curve. If None, the plot is drawn on a new set of axes.
@@ -958,11 +959,21 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
958959 Use e.g. "small", "medium", "large" or integer-values. Defaults to
959960 "medium".
960961
962+ pos_label (int, float, bool, str, optional): The positive label for binary
963+ classification. If `None`, the positive label is inferred from `y_true`.
964+ If `y_true` contains string labels or labels other than {0, 1} or {-1, 1},
965+ you must specify this parameter explicitly.
966+
961967 Returns:
962968 :class:`matplotlib.axes.Axes`: The axes on which the plot was drawn.
963969
964970 Example:
965971 >>> import scikitplot as skplt
972+ >>> from sklearn.ensemble import RandomForestClassifier
973+ >>> from sklearn.linear_model import LogisticRegression
974+ >>> from sklearn.naive_bayes import GaussianNB
975+ >>> from sklearn.svm import LinearSVC
976+ >>> from sklearn.metrics import calibration_curve
966977 >>> rf = RandomForestClassifier()
967978 >>> lr = LogisticRegression()
968979 >>> nb = GaussianNB()
@@ -976,7 +987,8 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
976987 ... 'Gaussian Naive Bayes', 'Support Vector Machine']
977988 >>> skplt.metrics.plot_calibration_curve(y_test,
978989 ... probas_list,
979- ... clf_names)
990+ ... clf_names,
991+ ... pos_label='1')
980992 <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
981993 >>> plt.show()
982994
@@ -1022,7 +1034,7 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
10221034 probas = (probas - probas .min ()) / (probas .max () - probas .min ())
10231035
10241036 fraction_of_positives , mean_predicted_value = \
1025- calibration_curve (y_true , probas , n_bins = n_bins )
1037+ calibration_curve (y_true , probas , n_bins = n_bins , pos_label = pos_label )
10261038
10271039 color = plt .cm .get_cmap (cmap )(float (i ) / len (probas_list ))
10281040
0 commit comments