从零开始搭建决策树——手撕CART算法(C++)
CART决策树的基本原理见CART决策树原理。
本文的C++代码基于C++ 20标准(不包含C++ modules),对于之前的标准,可能需要做一些适配。
CART分类树和回归树的内容各自在一个类中,分类树为CartClassifier类,回归树为CartRegression类。
数据结构设计
二叉树设计
// 二叉树结点
struct BinTreeNode
{
std::string threshold_str_;
double threshold_ = -1;
std::string feature_name_;
std::shared_ptr<BinTreeNode> left_ = nullptr;
std::shared_ptr<BinTreeNode> right_ = nullptr;
[[nodiscard]] std::shared_ptr<BinTreeNode> copy() const
{
auto node = std::make_shared<BinTreeNode>();
node->threshold_ = threshold_;
node->threshold_str_ = threshold_str_;
node->feature_name_ = feature_name_;
if (left_)
node->left_ = left_->copy();
if (right_)
node->right_ = right_->copy();
return node;
}
};
copy模块用于二叉树结点的深复制,包括复制本身及其所有的子结点。
结点信息设计
struct Info
{
std::shared_ptr<BinTreeNode> tree_;
size_t num_leaf_ = 0;
double a = 0;
std::pair<bool, std::string> key_str_{};
std::pair<bool, double> key_{};
};
实际上结点信息可以直接存储到二叉树结点BinTreeNode中。分开是为了保证代码的语义清晰,易于理解。
分类树
训练
/**
* @brief
* 训练决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param feature_names 属性名
* @return 生成的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::train(const vector<vector<string>>& X, const vector<string>& y, const vector<string>& feature_names)
{
feature_names_ = feature_names;
// 创建CART决策树
tree_ = create_tree(X, y);
return tree_;
}
训练函数通过传递常量引用形参,防止训练集和属性集被篡改。如需要修改,可以在函数内部设置副本,针对副本进行修改。create_tree是创建CART分类树的核心函数。
/**
* @brief
* 创建树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 训练好的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::create_tree(const vector<vector<string>>& X, const vector<string>& y)
{
// 若X中样本全属于同一类别C,则停止划分
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1)
{
tree->threshold_str_ = y.front();
return tree;
}
// 若节点样本数小于min_samples_split,或者属性集上的取值均相同
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1)
{
tree->threshold_str_ = majority_y(y);
return tree;
}
// 按照“基尼增益”,从属性值中选择最优分裂属性的最优切分点
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string best_feature_name = feature_names_[best_feature_index];
// 根据最优切分点,进行子树的划分
vector<vector<string>> sub_X1, sub_X2;
vector<string> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
if (X[i][best_feature_index] == best_split_point)
{
sub_X1.emplace_back(X[i]);
sub_y1.emplace_back(y[i]);
}
else
{
sub_X2.emplace_back(X[i]);
sub_y2.emplace_back(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_str_ = best_split_point;
tree->left_ = create_tree(sub_X1, sub_y1);
tree->right_ = create_tree(sub_X2, sub_y2);
return tree;
}
create_tree函数是一个递归创建决策树的过程。首先判断三种递归中止条件:
X中样本全部属于同一类别;- 当前节点样本数小于
min_samples_split_; - 属性集上的取值均相同
若满足终止条件,则选择中最多的类别作为结果返回。若未满足终止条件,依次执行以下步骤:
- 根据基尼指数从属性值中选择最优分裂属性的最优切分点,具体过程如
choose_best_point_to_split函数所示; - 根据最优切分点对子树进行划分;
- 对于其子树再继续执行
create_tree函数完成划分过程。
/**
* @brief
* 统计每个类别出现的次数,返回出现次数最大的类别ID
* @param y 目标变量集合
* @return 出现次数最大的类别
*/
string CartClassifier::majority_y(const vector<string>& y)
{
// 统计y中的目标变量值的个数
unordered_map<string, int> y_count;
for (const string& v : y)
{
if (!y_count.contains(v))
y_count[v] = 0;
++y_count[v];
}
return ranges::max_element(y_count, [](const pair<string, int>& a, const pair<string, int>& b) { return a.second < b.second; })->first;
}
majority_y用于计算节点中出现次数最多的类别。包含以下步骤:
- 初始化一个空映射;
- 遍历
y并对其元素进行计数; - 从映射中查找出现次数最多的类别。
/**
* @brief
* 选择最优切分点
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 最优切分点和最优切分点所在属性的索引
*/
pair<string, int> CartClassifier::choose_best_point_to_split(const vector<vector<string>>& X, const vector<string>& y)
{
string best_split_point;
int best_feature_index = -1;
double best_gini_index = numeric_limits<double>::infinity();
const size_t num_feature = X[0].size(); // 属性的个数
for (int i = 0; i < num_feature; i++) // 遍历每个属性
{
// 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
unordered_set<string> split_points;
for (const vector<string>& x : X)
split_points.emplace(x[i]);
for (const string& split_point : split_points) // 计算各个候选切分点的基尼不纯度
{
vector<string> sub_y_left, sub_y_right;
for (int j = 0; j < X.size(); j++)
if (X[j][i] == split_point)
sub_y_left.emplace_back(y[j]);
else
sub_y_right.emplace_back(y[j]);
// 计算左子树的基尼不纯度
const double gini_impurity_left = cal_gini_impurity(sub_y_left);
// 计算右子树的基尼不纯度
const double gini_impurity_right = cal_gini_impurity(sub_y_right);
// 计算该切分点的基尼指数
const double pro_left = static_cast<double>(sub_y_left.size()) / static_cast<double>(y.size()), pro_right = static_cast<double>(sub_y_right.size()) / static_cast<double>(y.size());
if (const double gini_index = cal_gini_index(pro_left, pro_right, gini_impurity_left, gini_impurity_right); best_gini_index > gini_index) // 取基尼指数最大的属性索引和切分点
{
best_gini_index = gini_index;
best_feature_index = i;
best_split_point = split_point;
}
}
}
return {best_split_point, best_feature_index};
}
choose_best_point_to_split是CART分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。该函数遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度和
,以及左右子树中数据样本在总样本中占的比例
和
,并且将
代入
cal_gini_index函数中计算基尼指数。最后选出具有最小基尼指数的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。
/**
* @brief
* 计算数据集的基尼不纯度
* @param y 目标变量集合
* @return 基尼不纯度 double
*/
double CartClassifier::cal_gini_impurity(const vector<string>& y)
{
// 统计y中的目标变量值的个数
unordered_map<string, int> y_count;
for (const string& v : y)
{
if (!y_count.contains(v))
y_count[v] = 0;
++y_count[v];
}
// 计算基尼不纯度
double gini_impurity = 1;
const auto num_samples = static_cast<double>(y.size());
for (const int& k : y_count | views::values)
{
const double prob = k / num_samples;
gini_impurity -= prob * prob;
}
return gini_impurity;
}
cal_gini_impurity用于计算基尼不纯度,包含以下步骤:
- 分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中;
- 遍历该字典,根据公式用1减去不同的类分布概率的平方和,得到最终的基尼不纯度。
/**
* @brief
* 计算基尼指数
* @param pro_left 左子树比例
* @param pro_right 右子树比例
* @param gini_impurity_left 左子树的基尼不纯度
* @param gini_impurity_right 右子树的基尼不纯度
* @return 基尼指数 double
*/
double CartClassifier::cal_gini_index(const double pro_left, const double pro_right, const double gini_impurity_left, const double gini_impurity_right)
{
return pro_left * gini_impurity_left + pro_right * gini_impurity_right;
}
cal_gini_index通过公式计算基尼指数。
预测
/**
* @brief
* 使用决策树进行预测
* @param X 测试集属性值
* @return 预测值
*/
vector<string> CartClassifier::predict(const vector<vector<string>>& X)
{
vector<string> y_preds;
for (const vector<string>& x : X)
y_preds.emplace_back(classify(tree_, x));
return y_preds;
}
遍历测试集X的每个样本,使用classify函数分别对其进行预测,最终返回拼接好的预测结果。
/**
* @brief
* 分类预测
* @param tree 训练好的CART树
* @param x 待分类样本
* @return 预测类
*/
string CartClassifier::classify(const shared_ptr<BinTreeNode>& tree, const vector<string>& x)
{
const string& first_str = tree->feature_name_; // 根节点
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const string& current_value = x[feature_index];
if (tree->left_ && current_value == tree->threshold_str_)
return classify(tree->left_, x);
if (tree->right_ && current_value != tree->threshold_str_)
return classify(tree->right_, x);
return tree->threshold_str_;
}
通过调用classify进行预测分类。参数tree的根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。在递归遍历过程中,从根节点开始,递归遍历CART分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。
剪枝
/**
* @brief
* 代价复杂度剪枝CCP
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 剪枝后的决策树集合
*/
vector<shared_ptr<BinTreeNode>> CartClassifier::pruning(const vector<vector<string>>& X, const vector<string>& y)
{
// 递归计算对当前树的每个子树的g(ti),挑选最小的g(ti)进行剪枝,得到新的T,最终得到n个T
return split_n_best_trees(X, y);
}
函数pruning根据不同的区间生成不同剪枝程度的决策树集合。集合中越后面的决策树,剪枝程度越高。
/**
* @brief
* 根据g(ti)生成n个误差最小的树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return n个误差最小的树
*/
vector<shared_ptr<BinTreeNode>> CartClassifier::split_n_best_trees(const vector<vector<string>>& X, const vector<string>& y)
{
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y))
{
trees.emplace_back(best_tree);
tree = best_tree->copy();
}
else
tree = nullptr;
return trees;
}
split_n_best_trees函数通过调用split_1_best_trees函数递归生成棵预测误差最小的树,每一次递归的初始树均为上一次递归得到的最优剪枝树。为了在递归过程中不破坏上一轮得到的最优剪枝树,使用了深拷贝。
/**
* @brief
* 计算α值,选出α值最小的剪枝树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return α值最小的剪枝树
*/
shared_ptr<BinTreeNode> CartClassifier::split_1_best_trees(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y)
{
// 构建节点信息总集合
vector<Info> infoSet;
// 计算数据集长度
const size_t NT = X.size();
// 计算误差增加率,并生成信息集合
calErrorRatio(tree, X, y, NT, infoSet);
if (infoSet.empty())
return nullptr;
// a的比较基准值
double baseValue = 1;
int bestNode = 0;
for (int i = 0; i < infoSet.size(); i++)
if (infoSet[i].a < baseValue)
{
baseValue = infoSet[i].a;
bestNode = i;
}
else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
return prunBranch(tree, X, y, infoSet[bestNode]);
}
函数split_1_best_tree负责递归计算值,并且选出
值最小的剪枝树。当前树的深度大于1时,开始进行CCP的迭代剪枝。在每次迭代内部,对每个分支节点进行
的计算,并选取最小值对应的子树进行剪枝。如果求得的最小
对应的子树有多个,则优先选取节点数目最多的子树作为修剪的对象。
/**
* @brief
* 计算非叶节点误差增加率
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param NT 数据集总样本数目
* @param infoSet 所有节点的信息总集合
* @return 各个节点的信息集
*/
Info CartClassifier::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y, const size_t NT, vector<Info>& infoSet)
{
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_ && (tree->left_->left_ || tree->left_->right_))
{
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_)
{
// 取第i行进subData
// 相当于把label特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
// 在节点信息集中,增加分类前特征
info.key_str_ = {true, tree->threshold_str_};
infoSet.emplace_back(info);
}
if (tree->right_ && (tree->right_->left_ || tree->right_->right_))
{
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] != tree->threshold_str_)
{
// 取第i行进subData
// 相当于把label特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
// 在节点信息集中,增加分类前特征
info.key_str_ = {false, tree->threshold_str_};
infoSet.emplace_back(info);
}
// 计算节点误差率
const double Ct = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
// 计算子树误差率
const double CTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
// 计算叶节点数目
const size_t Nt = getNumLeaf(tree);
const double a = Nt == 1 ? 2 : (Ct - CTt) / static_cast<double>(Nt - 1);
return {tree, Nt, a};
}
每次迭代中的计算,也就是
calErrorRatio函数。该函数主要计算节点的误差率
、节点
对应子树
的误差率
、子树叶子节点的数目
。
的计算采用递归的方法,最终将所有
info合并成节点信息集合。
/**
* @brief
* 计算非叶节点的误差
* @param y 训练集目标变量
* @return 误差
*/
size_t CartClassifier::nodeError(const vector<string>& y)
{
// 找到数量最多的类别
string majorClass = majority_y(y);
// 游历数据集每个元素,找出正确样本个数,如果不一致,错误加1
return ranges::count_if(y, [&majorClass](const string& v) { return v != majorClass; });
}
/**
* @brief
* 计算叶节点的误差
* @param tree 生成的决策树
* @param X
* @param y 训练集目标变量
* @return 误差
*/
size_t CartClassifier::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y)
{
size_t error = 0;
for (int i = 0; i < X.size(); i++)
if (classify(tree, X[i]) != y[i])
++error;
return error;
}
/**
* @brief
* 获取叶节点数量
* @param tree 决策树
* @return 返回树的叶节点
*/
size_t CartClassifier::getNumLeaf(const shared_ptr<BinTreeNode>& tree)
{
size_t numLeafs = 0;
if (tree->left_)
numLeafs += getNumLeaf(tree->left_);
if (tree->right_)
numLeafs += getNumLeaf(tree->right_);
if (!tree->left_ && !tree->right_)
++numLeafs;
return numLeafs;
}
/**
* @brief
* 根据误差增加率,剪掉子树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param infoBran 需剪掉的子树信息集
* @return 剪枝后的决策树
*/
shared_ptr<BinTreeNode> CartClassifier::prunBranch(const shared_ptr<BinTreeNode>& tree, const vector<vector<string>>& X, const vector<string>& y, const Info& infoBran)
{
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_)
{
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] == tree->threshold_str_)
{
// 取第i行进subData
// 相当于把label特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
// 找到数量最多的类别
const string majorClass = majority_y(sub_y);
// 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
if (infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ && tree->left_ == infoBran.tree_)
{
// 剪掉子树,即返回最大类
tree->left_ = make_shared<BinTreeNode>();
tree->left_->threshold_str_ = majorClass;
return tree;
}
// 如果不相同,继续向下寻找
tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
}
if (tree->right_)
{
// 划分数据集
vector<vector<string>> sub_X;
vector<string> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] != tree->threshold_str_)
{
// 取第i行进subData
// 相当于把label特征取值剔除,将其他特征取值输出
// 将每个符合条件的特征列表,组成列表集合
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
// 找到数量最多的类别
const string majorClass = majority_y(sub_y);
// 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
if (!infoBran.key_str_.first && infoBran.key_str_.second == tree->threshold_str_ && tree->right_ == infoBran.tree_)
{
// 剪掉子树,即返回最大类
tree->right_ = make_shared<BinTreeNode>();
tree->right_->threshold_str_ = majorClass;
return tree;
}
// 如果不相同,继续向下寻找
tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
}
return tree;
}
应用注意事项
- 为方便理解,代码仅考虑了离散字符串的分类,并未考虑其他离散值和连续值的分类,实际生产过程可能需要补充;
- 原则上来说,代码数据集中的字符串均需要通过编码(分类算法中编码无限制),以提升效率。为方便理解,本文章使用原始字符串,不影响结果;
- 代码中的
feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量; - 对于CCP误差的计算,scikit-learn使用基尼不纯度进行代替,因其不用每次使用预测计算,提高了效率。但基尼不纯度与误差之间仅具有相关性,无法通过基尼不纯度推导出误差,仅用作近似计算;
- 代码未考虑缺失值的处理;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。
回归树
训练
CartRegressor的创建和训练过程与CartClassifier类似。最重要的区别在于模型训练时切分点的选取。
/**
* @brief
* 训练决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param feature_names 属性名
* @return 生成的决策树
*/
shared_ptr<BinTreeNode> CartRegressor::train(const vector<vector<double>>& X, const vector<double>& y, const vector<string>& feature_names)
{
feature_names_ = feature_names;
tree_ = create_tree(X, y);
return tree_;
}
train的入参类型发生了变化,这是因为回归树使用的是连续类型数据。
/**
* @brief
* 创建树
* @param X 映射后的数值属性集
* @param y 映射后的数值目标变量集
* @return 训练好的决策树
*/
shared_ptr<BinTreeNode> CartRegressor::create_tree(const vector<vector<double>>& X, const vector<double>& y)
{
// 若X中样本全属于同一类别C,则停止划分
auto tree = make_shared<BinTreeNode>();
if (unordered_set(y.begin(), y.end()).size() == 1)
{
tree->threshold_ = y.front();
return tree;
}
// 若节点样本数小于min_samples_split,或者属性集上的取值均相同
if (y.size() <= min_samples_split_ || set(X.begin(), X.end()).size() == 1)
{
tree->threshold_ = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
return tree;
}
// 按照“平方误差最小”,从feature_names中选择最优切分点
auto [best_split_point, best_feature_index] = choose_best_point_to_split(X, y);
const string_view best_feature_name = feature_names_[best_feature_index];
// 根据最优切分点,进行子树的划分
vector<vector<double>> sub_X1, sub_X2;
vector<double> sub_y1, sub_y2;
for (int i = 0; i < X.size(); i++)
if (X[i][best_feature_index] <= best_split_point)
{
sub_X1.emplace_back(X[i]);
sub_y1.emplace_back(y[i]);
}
else
{
sub_X2.emplace_back(X[i]);
sub_y2.emplace_back(y[i]);
}
tree->feature_name_ = best_feature_name;
tree->threshold_ = best_split_point;
tree->left_ = create_tree(sub_X1, sub_y1);
tree->right_ = create_tree(sub_X2, sub_y2);
return tree;
}
在函数create_tree中,主要有3处与分类树不同:
- 当满足递归终止条件“节点样本数小于
min_samples_split_”时,返回的预测值是该集合中所有目标变量的平均值; - 在
choose_best_point_to_split函数中,在回归树中采用“平方误差最小”的原则来选择最优切分点; - 使用最优属性和最优切分点划分数据集时相较分类树(处理匹配字符串“
”和“
”的代码逻辑)做略微调整。
/**
* @brief
* 选择最优切分点
* @param X 映射后的数值属性集
* @param y 属性名称
* @return 最优切分点和最优切分点所在属性的索引
*/
pair<double, int> CartRegressor::choose_best_point_to_split(const vector<vector<double>>& X, const vector<double>& y)
{
double best_split_point = 0, best_loss_all = numeric_limits<double>::infinity();
int best_feature_index = -1;
const size_t num_feature = X[0].size(); // 属性的个数
for (int i = 0; i < num_feature; ++i) // 遍历每个属性
{
// 得到某个属性下的所有值,即某列,并去重,得到无重复的属性特征值
set<double> unique_feature_value;
vector<double> split_points;
for (const vector<double>& x : X)
unique_feature_value.emplace(x[i]);
auto lit = unique_feature_value.begin(), rit = lit;
++rit;
while (rit != unique_feature_value.end())
{
split_points.emplace_back((*lit + *rit) / 2);
++lit;
++rit;
}
// 计算各个候选切分点的损失函数
for (const double split_point : split_points)
{
vector<double> sub_y_left, sub_y_right;
for (int j = 0; j < X.size(); j++)
if (X[j][i] <= split_point)
sub_y_left.emplace_back(y[j]);
else
sub_y_right.emplace_back(y[j]);
const double sub_y_left_mean = accumulate(sub_y_left.begin(), sub_y_left.end(), 0.) / static_cast<double>(sub_y_left.size()), sub_y_right_mean = accumulate(sub_y_right.begin(), sub_y_right.end(), 0.) / static_cast<double>(sub_y_right.size());
double loss_left = 0, loss_right = 0;
// 计算左子树的损失函数
for (const double j : sub_y_left)
loss_left += pow(j - sub_y_left_mean, 2);
// 计算右子树的损失函数
for (const double j : sub_y_right)
loss_right += pow(j - sub_y_right_mean, 2);
// 计算该切分点的总损失函数
// 取损失函数最小时的属性索引和切分点
if (const double loss_all = loss_left + loss_right; best_loss_all > loss_all)
{
best_loss_all = loss_all;
best_feature_index = i;
best_split_point = split_point;
}
}
}
return {best_split_point, best_feature_index};
}
choose_best_point_to_split遍历所有属性值时,回归树中不再计算基尼不纯度和基尼增益,而是针对回归问题计算损失函数。分别计算了使用当前切分点划分的左右子树的残差平方和,再计算左右子树的总残差平方和。最后选出取得最小损失函数的切分点和属性索引,作为最优切分点和最优分裂属性。
预测
/**
* @brief
* 使用决策树进行预测
* @param X 测试集属性值
* @return 预测值
*/
vector<double> CartRegressor::predict(const vector<vector<double>>& X)
{
vector<double> y_preds;
for (const vector<double>& x : X)
y_preds.emplace_back(regression(tree_, x));
return y_preds;
}
/**
* @brief
* 回归预测
* @param tree 训练好的树
* @param x 待分类样本
* @return 预测类
*/
double CartRegressor::regression(const shared_ptr<BinTreeNode>& tree, const vector<double>& x)
{
const string& first_str = tree->feature_name_; // 根节点
const size_t feature_index = distance(feature_names_.begin(), ranges::find(feature_names_, first_str));
const double current_value = x[feature_index];
if (tree->left_ && current_value <= tree->threshold_)
return regression(tree->left_, x);
if (tree->right_ && current_value > tree->threshold_)
return regression(tree->right_, x);
return tree->threshold_;
}
由于CART回归树与分类树的预测过程几乎完全相同,在此不做赘述。
剪枝
/**
* @brief
* 代价复杂度剪枝CCP
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 剪枝后的决策树集合
*/
vector<shared_ptr<BinTreeNode>> CartRegressor::pruning(const vector<vector<double>>& X, const vector<double>& y)
{
// 递归计算对当前树的每个子树的g(ti),挑选最小的g(ti)进行剪枝,得到新的T,最终得到n个T
return split_n_best_trees(X, y);
}
/**
* @brief
* 根据g(ti)生成n个误差最小的树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return n个误差最小的树
*/
vector<shared_ptr<BinTreeNode>> CartRegressor::split_n_best_trees(const vector<vector<double>>& X, const vector<double>& y)
{
vector<shared_ptr<BinTreeNode>> trees;
shared_ptr<BinTreeNode> tree = tree_->copy();
while (tree)
if (shared_ptr<BinTreeNode> best_tree = split_1_best_trees(tree, X, y))
{
trees.emplace_back(best_tree);
tree = best_tree->copy();
}
else
tree = nullptr;
return trees;
}
/**
* @brief
* 计算α值,选出α值最小的剪枝树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return α值最小的剪枝树
*/
shared_ptr<BinTreeNode> CartRegressor::split_1_best_trees(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y)
{
// 构建节点信息总集合
vector<Info> infoSet;
// 计算数据集长度
const size_t NT = X.size();
// 计算误差增加率,并生成信息集合
calErrorRatio(tree, X, y, NT, infoSet);
if (infoSet.empty())
return nullptr;
// a的比较基准值
double baseValue = 1;
int bestNode = 0;
for (int i = 0; i < infoSet.size(); i++)
if (infoSet[i].a < baseValue)
{
baseValue = infoSet[i].a;
bestNode = i;
}
else if (infoSet[i].a == baseValue && infoSet[i].num_leaf_ > infoSet[bestNode].num_leaf_)
bestNode = i;
return prunBranch(tree, X, y, infoSet[bestNode]);
}
/**
* @brief
* 计算非叶节点误差增加率
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param NT 数据集总样本数目
* @param infoSet 所有节点的信息总集合
* @return 各个节点的信息集
*/
Info CartRegressor::calErrorRatio(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y, size_t NT, vector<Info>& infoSet)
{
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_ && (tree->left_->left_ || tree->left_->right_))
{
// 划分数据集
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] <= tree->threshold_)
{
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->left_, sub_X, sub_y, NT, infoSet);
// 在节点信息集中,增加分类前特征
info.key_ = {true, tree->threshold_};
infoSet.emplace_back(info);
}
if (tree->right_ && (tree->right_->left_ || tree->right_->right_))
{
// 划分数据集
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] > tree->threshold_)
{
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
Info info = calErrorRatio(tree->right_, sub_X, sub_y, NT, infoSet);
// 在节点信息集中,增加分类前特征
info.key_ = {false, tree->threshold_};
infoSet.emplace_back(info);
}
// 计算节点误差率
const double Rt = static_cast<double>(nodeError(y)) / static_cast<double>(NT);
// 计算子树误差率
const double RTt = static_cast<double>(leafError(tree, X, y)) / static_cast<double>(NT);
// 计算叶节点数目
const size_t Nt = getNumLeaf(tree);
const double a = Nt == 1 ? 2 : (Rt - RTt) / static_cast<double>(Nt - 1);
return {tree, Nt, a};
}
/**
* @brief
* 计算非叶节点的误差
* @param y 训练集目标变量
* @return 误差
*/
size_t CartRegressor::nodeError(const vector<double>& y)
{
// 计算节点的平方误差
const double mean_y = accumulate(y.begin(), y.end(), 0.) / static_cast<double>(y.size());
size_t error = 0;
for (const double& val : y)
error += static_cast<size_t>(pow(val - mean_y, 2));
return error;
}
/**
* @brief
* 计算叶节点的误差
* @param tree 生成的决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @return 误差
*/
size_t CartRegressor::leafError(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y)
{
size_t error = 0;
for (int i = 0; i < X.size(); i++)
{
const double pred = regression(tree, X[i]);
error += static_cast<size_t>(pow(pred - y[i], 2));
}
return error;
}
/**
* @brief
* 获取叶节点数量
* @param tree 决策树
* @return 返回树的叶节点
*/
size_t CartRegressor::getNumLeaf(const shared_ptr<BinTreeNode>& tree)
{
size_t numLeafs = 0;
if (tree->left_)
numLeafs += getNumLeaf(tree->left_);
if (tree->right_)
numLeafs += getNumLeaf(tree->right_);
if (!tree->left_ && !tree->right_)
++numLeafs;
return numLeafs;
}
/**
* @brief
* 根据误差增加率,剪掉子树
* @param tree 决策树
* @param X 训练集属性值
* @param y 训练集目标变量
* @param infoBran 需剪掉的子树信息集
* @return 剪枝后的决策树
*/
shared_ptr<BinTreeNode> CartRegressor::prunBranch(const shared_ptr<BinTreeNode>& tree, const vector<vector<double>>& X, const vector<double>& y, const Info& infoBran)
{
const string_view firstFeat = tree->feature_name_;
const size_t labelIndex = distance(feature_names_.begin(), ranges::find(feature_names_, firstFeat));
if (tree->left_)
{
// 划分数据集
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] <= tree->threshold_)
{
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
// 计算该分支的平均值
const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
// 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
if (infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 && tree->left_ == infoBran.tree_)
{
// 剪掉子树,即返回平均值
tree->left_ = make_shared<BinTreeNode>();
tree->left_->threshold_ = mean_val;
return tree;
}
// 如果不相同,继续向下寻找
tree->left_ = prunBranch(tree->left_, sub_X, sub_y, infoBran);
}
if (tree->right_)
{
// 划分数据集
vector<vector<double>> sub_X;
vector<double> sub_y;
for (int i = 0; i < X.size(); i++)
if (X[i][labelIndex] > tree->threshold_)
{
sub_X.emplace_back(X[i]);
sub_y.emplace_back(y[i]);
}
// 计算该分支的平均值
const double mean_val = accumulate(sub_y.begin(), sub_y.end(), 0.) / static_cast<double>(sub_y.size());
// 如果当前子树分类前特征和子树都和预处理相同,则把该子树剪掉
if (!infoBran.key_.first && abs(infoBran.key_.second - tree->threshold_) < 1e-9 && tree->right_ == infoBran.tree_)
{
// 剪掉子树,即返回平均值
tree->right_ = make_shared<BinTreeNode>();
tree->right_->threshold_ = mean_val;
return tree;
}
// 如果不相同,继续向下寻找
tree->right_ = prunBranch(tree->right_, sub_X, sub_y, infoBran);
}
return tree;
}
回归树的剪枝与分类树类似,不同点在于回归树计算误差使用的是均方差。
应用注意事项
- 代码中的
feature_name仅作画图需要,实际生产如无该需求,可以去掉该变量; - 代码未考虑缺失值的处理;
- 分类树和回归树中的CCP算法,仅在误差计算中有区别。分类树中可以使用基尼系数或误分类率(从效率层面,推荐使用基尼系数),回归树中使用均方差;
- 代码没有适配多线程场景;
- 其他可能的算法时空复杂度的优化。
查看11道真题和解析