class MyDecisionTree:
    ...

    def _fit(self, train_data, y, features_indices, feature_labels):
        ...
        # (3) Compute the information gain per Eq. (5.10) and select the feature with the largest gain
        max_feature = 0
        max_gda = 0
        D = y.copy()
        # Loop over every feature in the feature set A
        for feature in features_indices:
            # Take column `feature` of the training set (i.e. the feature-th feature)
            A = np.array(train_data[:, feature].flat)
            # Compute the information gain
            gda = self._calc_ent_grap(A, D)
            if self._calc_ent(D) != 0:
                # Compute the information gain ratio
                gda /= self._calc_ent(D)
            # Select the feature Ag with the largest information gain
            if gda > max_gda:
                max_gda, max_feature = gda, feature
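For reference, these are the usual C4.5 definitions. Here D is split into subsets D_1, …, D_n by the n distinct values of feature A, and C_1, …, C_K are the classes. The denominator of the ratio is H_A(D), the entropy of the feature column itself. H(D) is the entropy of the labels.

```latex
H(D)      = -\sum_{k=1}^{K} \frac{|C_k|}{|D|} \log_2 \frac{|C_k|}{|D|}
H(D \mid A) = \sum_{i=1}^{n} \frac{|D_i|}{|D|} \, H(D_i)
g(D, A)   = H(D) - H(D \mid A)
H_A(D)    = -\sum_{i=1}^{n} \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|}
g_R(D, A) = \frac{g(D, A)}{H_A(D)}
```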
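The code above is meant to compute the information gain ratio, but it divides by `self._calc_ent(D)`, the entropy of the labels. I think the denominator should be `self._calc_ent(A)`, i.e. the entropy of the training set with respect to the values of feature A. Could you please confirm?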
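To illustrate why the denominator matters, here is a minimal standalone sketch. `entropy`, `info_gain`, `gain_ratio`, `ratio_as_written` and `select_feature` are hypothetical stand-ins written for this example. The repo's actual `_calc_ent` and `_calc_ent_grap` are elided in the issue. The sketch assumes those helpers compute the base-2 entropy of an array's value distribution and the information gain g(D, A). Because H(D) is the same constant for every feature, dividing by it cannot change which feature wins. The quoted loop therefore behaves like plain information gain and still favors many-valued features:

```python
import numpy as np


def entropy(values):
    """Base-2 empirical entropy of the value distribution of a 1-D array."""
    _, counts = np.unique(values, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def info_gain(A, D):
    """g(D, A) = H(D) - H(D|A), with A a feature column and D the labels."""
    cond = 0.0
    for a in np.unique(A):
        mask = A == a
        cond += mask.mean() * entropy(D[mask])  # |D_i|/|D| * H(D_i)
    return entropy(D) - cond


def ratio_as_written(A, D):
    # Mirrors the quoted loop: divide g(D, A) by H(D) when H(D) != 0.
    g = info_gain(A, D)
    h_d = entropy(D)
    return g / h_d if h_d != 0 else g


def gain_ratio(A, D):
    # C4.5 gain ratio: divide g(D, A) by H_A(D), the entropy of the feature column.
    h_a = entropy(A)
    return info_gain(A, D) / h_a if h_a > 0 else 0.0


def select_feature(X, y, criterion):
    """Return the index of the column with the highest criterion score."""
    best_feature, best_score = 0, 0.0
    for feature in range(X.shape[1]):
        score = criterion(X[:, feature], y)
        if score > best_score:
            best_score, best_feature = score, feature
    return best_feature


# Toy data: column 0 is an ID-like feature (unique per row),
# column 1 is a binary feature that separates the classes fairly well.
X = np.array([[0, 0], [1, 0], [2, 0], [3, 0], [4, 1], [5, 1]])
y = np.array([0, 0, 0, 1, 1, 1])

print(select_feature(X, y, ratio_as_written))  # → 0 (same choice as plain information gain)
print(select_feature(X, y, gain_ratio))        # → 1
```

In this toy example the ID-like column has information gain 1.0 but H_A(D) = log2(6), so its gain ratio drops to about 0.39. The binary column's ratio is about 0.5.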