cart决策树 案例
- 格式:doc
- 大小:36.72 KB
- 文档页数:2
决策树系列(五)——CARTCART,⼜名分类回归树,是在ID3的基础上进⾏优化的决策树,学习CART记住以下⼏个关键点:(1)CART既能是分类树,⼜能是分类树;(2)当CART是分类树时,采⽤GINI值作为节点分裂的依据;当CART是回归树时,采⽤样本的最⼩⽅差作为节点分裂的依据;(3)CART是⼀棵⼆叉树。
接下来将以⼀个实际的例⼦对CART进⾏介绍: 表1 原始数据表看电视时间婚姻情况职业年龄3未婚学⽣124未婚学⽣182已婚⽼师265已婚上班族472.5已婚上班族363.5未婚⽼师294已婚学⽣21从以下的思路理解CART:分类树?回归树?分类树的作⽤是通过⼀个对象的特征来预测该对象所属的类别,⽽回归树的⽬的是根据⼀个对象的信息预测该对象的属性,并以数值表⽰。
CART既能是分类树,⼜能是决策树,如上表所⽰,如果我们想预测⼀个⼈是否已婚,那么构建的CART将是分类树;如果想预测⼀个⼈的年龄,那么构建的将是回归树。
分类树和回归树是怎么做决策的?假设我们构建了两棵决策树分别预测⽤户是否已婚和实际的年龄,如图1和图2所⽰: 图1 预测婚姻情况决策树图2 预测年龄的决策树图1表⽰⼀棵分类树,其叶⼦节点的输出结果为⼀个实际的类别,在这个例⼦⾥是婚姻的情况(已婚或者未婚),选择叶⼦节点中数量占⽐最⼤的类别作为输出的类别;图2是⼀棵回归树,预测⽤户的实际年龄,是⼀个具体的输出值。
怎样得到这个输出值?⼀般情况下选择使⽤中值、平均值或者众数进⾏表⽰,图2使⽤节点年龄数据的平均值作为输出值。
CART如何选择分裂的属性?分裂的⽬的是为了能够让数据变纯,使决策树输出的结果更接近真实值。
那么CART是如何评价节点的纯度呢?如果是分类树,CART采⽤GINI值衡量节点纯度;如果是回归树,采⽤样本⽅差衡量节点纯度。
节点越不纯,节点分类或者预测的效果就越差。
GINI值的计算公式:节点越不纯,GINI值越⼤。
以⼆分类为例,如果节点的所有数据只有⼀个类别,则,如果两类数量相同,则。
决策树的C++实现(CART)decision_tree.hpp⽂件内容如下:#ifndef FBC_NN_DECISION_TREE_HPP_#define FBC_NN_DECISION_TREE_HPP_#include <vector>#include <tuple>#include <fstream>namespace ANN {// referecne: https:///implement-decision-tree-algorithm-scratch-python/template<typename T>class DecisionTree { // CART(Classification and Regression Trees)public:DecisionTree() = default;~DecisionTree() { delete_tree(); }int init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes);void set_max_depth(int max_depth) { this->max_depth = max_depth; }int get_max_depth() const { return max_depth; }void set_min_size(int min_size) { this->min_size = min_size; }int get_min_size() const { return min_size; }void train();int save_model(const char* name) const;int load_model(const char* name);T predict(const std::vector<T>& data) const;protected:typedef std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> dictionary; // index of attribute, value of attribute, groups of data typedef std::tuple<int, int, T, T, T> row_element; // flag, index, value, class_value_left, class_value_righttypedef struct binary_tree {dictionary dict;T class_value_left = (T)-1.f;T class_value_right = (T)-1.f;binary_tree* left = nullptr;binary_tree* right = nullptr;} binary_tree;// Calculate the Gini index for a split datasetT gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const;// Select the best split point for a datasetdictionary get_split(const std::vector<std::vector<T>>& dataset) const;// Split a dataset based on an attribute and an attribute valuestd::vector<std::vector<std::vector<T>>> test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const;// Create a terminal node valueT to_terminal(const std::vector<std::vector<T>>& group) const;// Create child splits for a node or make terminalvoid split(binary_tree* node, int depth);// Build a decision treevoid build_tree(const std::vector<std::vector<T>>& train);// Print a decision treevoid print_tree(const binary_tree* node, int depth = 0) const;// Make a prediction with a decision treeT predict(binary_tree* node, const std::vector<T>& data) const;// calculate accuracy percentagedouble accuracy_metric() const;void delete_tree();void delete_node(binary_tree* node);void write_node(const binary_tree* node, std::ofstream& file) const;void node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const;int height_of_tree(const binary_tree* node) const;void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);private:std::vector<std::vector<T>> src_data;binary_tree* tree = nullptr;int samples_num = 0;int feature_length = 0;int classes_num = 0;int max_depth = 10; // maximum tree depthint min_size = 10; // minimum node recordsint max_nodes = -1;};} // namespace ANN#endif // FBC_NN_DECISION_TREE_HPP_decision_tree.cpp⽂件内容如下:#include "decision_tree.hpp"#include <set>#include <algorithm>#include <typeinfo>#include <iterator>#include "common.hpp"namespace ANN {template<typename T>int DecisionTree<T>::init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes){CHECK(data.size() != 0 && classes.size() != 0 && data[0].size() != 0);this->samples_num = data.size();this->classes_num = classes.size();this->feature_length = data[0].size() -1;for (int i = 0; i < this->samples_num; ++i) {this->src_data.emplace_back(data[i]);}return 0;}template<typename T>T DecisionTree<T>::gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const {// Gini calculation for a group// proportion = count(class_value) / count(rows)// gini_index = (1.0 - sum(proportion * proportion)) * (group_size/total_samples)// count all samples at split pointint instances = 0;int group_num = groups.size();for (int i = 0; i < group_num; ++i) {instances += groups[i].size();}// sum weighted Gini index for each groupT gini = (T)0.;for (int i = 0; i < group_num; ++i) {int size = groups[i].size();// avoid divide by zeroif (size == 0) continue;if (size == 0) continue;T score = (T)0.;// score the group based on the score for each classT p = (T)0.;for (int c = 0; c < classes.size(); ++c) {int count = 0;for (int t = 0; t < size; ++t) {if (groups[i][t][this->feature_length] == classes[c]) ++count;}T p = (float)count / size;score += p * p;}// weight the group score by its relative sizegini += (1. - score) * (float)size / instances;}return gini;}template<typename T>std::vector<std::vector<std::vector<T>>> DecisionTree<T>::test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const {std::vector<std::vector<std::vector<T>>> groups(2); // 0: left, 1: reightfor (int row = 0; row < dataset.size(); ++row) {if (dataset[row][index] < value) {groups[0].emplace_back(dataset[row]);} else {groups[1].emplace_back(dataset[row]);}}return groups;}template<typename T>std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> DecisionTree<T>::get_split(const std::vector<std::vector<T>>& dataset) const {std::vector<T> values;for (int i = 0; i < dataset.size(); ++i) {values.emplace_back(dataset[i][this->feature_length]);}std::set<T> vals(values.cbegin(), values.cend());std::vector<T> class_values(vals.cbegin(), vals.cend());int b_index = 999;T b_value = (T)999.;T b_score = (T)999.;std::vector<std::vector<std::vector<T>>> b_groups(2);for (int index = 0; index < this->feature_length; ++index) {for (int row = 0; row < dataset.size(); ++row) {std::vector<std::vector<std::vector<T>>> groups = test_split(index, dataset[row][index], dataset);T gini = gini_index(groups, class_values);if (gini < b_score) {b_index = index;b_value = dataset[row][index];b_score = gini;b_groups = groups;}}}}// a new node: the index of the chosen attribute, the value of that attribute by which to split and the two groups of data split by the chosen split point return std::make_tuple(b_index, b_value, b_groups);}template<typename T>T DecisionTree<T>::to_terminal(const std::vector<std::vector<T>>& group) const{std::vector<T> values;for (int i = 0; i < group.size(); ++i) {values.emplace_back(group[i][this->feature_length]);}std::set<T> vals(values.cbegin(), values.cend());int max_count = -1, index = -1;for (int i = 0; i < vals.size(); ++i) {int count = std::count(values.cbegin(), values.cend(), *std::next(vals.cbegin(), i));if (max_count < count) {max_count = count;index = i;}}return *std::next(vals.cbegin(), index);}template<typename T>void DecisionTree<T>::split(binary_tree* node, int depth){std::vector<std::vector<T>> left = std::get<2>(node->dict)[0];std::vector<std::vector<T>> right = std::get<2>(node->dict)[1];std::get<2>(node->dict).clear();// check for a no splitif (left.size() == 0 || right.size() == 0) {for (int i = 0; i < right.size(); ++i) {left.emplace_back(right[i]);}node->class_value_left = node->class_value_right = to_terminal(left);return;}// check for max depthif (depth >= max_depth) {node->class_value_left = to_terminal(left);node->class_value_right = to_terminal(right);return;}// process left childif (left.size() <= min_size) {node->class_value_left = to_terminal(left);} else {dictionary dict = get_split(left);node->left = new binary_tree;node->left->dict = dict;split(node->left, depth+1);}// process right childif (right.size() <= min_size) {node->class_value_right = to_terminal(right);} else {dictionary dict = get_split(right);dictionary dict = get_split(right);node->right = new binary_tree;node->right->dict = dict;split(node->right, depth+1);}}template<typename T>void DecisionTree<T>::build_tree(const std::vector<std::vector<T>>& train){// create root nodedictionary root = get_split(train);binary_tree* node = new binary_tree;node->dict = root;tree = node;split(node, 1);}template<typename T>void DecisionTree<T>::train(){this->max_nodes = (1 << max_depth) - 1;build_tree(src_data);accuracy_metric();//binary_tree* tmp = tree;//print_tree(tmp);}template<typename T>T DecisionTree<T>::predict(const std::vector<T>& data) const{if (!tree) {fprintf(stderr, "Error, tree is null\n");return -1111.f;}return predict(tree, data);}template<typename T>T DecisionTree<T>::predict(binary_tree* node, const std::vector<T>& data) const {if (data[std::get<0>(node->dict)] < std::get<1>(node->dict)) {if (node->left) {return predict(node->left, data);} else {return node->class_value_left;}} else {if (node->right) {return predict(node->right, data);} else {return node->class_value_right;}}}template<typename T>int DecisionTree<T>::save_model(const char* name) const{std::ofstream file(name, std::ios::out);if (!file.is_open()) {fprintf(stderr, "open file fail: %s\n", name);return -1;return -1;}file<<max_depth<<","<<min_size<<std::endl;binary_tree* tmp = tree;int depth = height_of_tree(tmp);CHECK(max_depth == depth);tmp = tree;write_node(tmp, file);file.close();return 0;}template<typename T>void DecisionTree<T>::write_node(const binary_tree* node, std::ofstream& file) const{/*if (!node) return;write_node(node->left, file);file<<std::get<0>(node->dict)<<","<<std::get<1>(node->dict)<<","<<node->class_value_left<<","<<node->class_value_right<<std::endl;write_node(node->right, file);*///typedef std::tuple<int, int, T, T, T> row; // flag, index, value, class_value_left, class_value_rightstd::vector<row_element> vec(this->max_nodes, std::make_tuple(-1, -1, (T)-1.f, (T)-1.f, (T)-1.f));binary_tree* tmp = const_cast<binary_tree*>(node);node_to_row_element(tmp, vec, 0);for (const auto& row : vec) {file<<std::get<0>(row)<<","<<std::get<1>(row)<<","<<std::get<2>(row)<<","<<std::get<3>(row)<<","<<std::get<4>(row)<<std::endl;}}template<typename T>void DecisionTree<T>::node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const{if (!node) return;rows[pos] = std::make_tuple(0, std::get<0>(node->dict), std::get<1>(node->dict), node->class_value_left, node->class_value_right); // 0: have node, -1: no nodeif (node->left) node_to_row_element(node->left, rows, 2*pos+1);if (node->right) node_to_row_element(node->right, rows, 2*pos+2);}template<typename T>int DecisionTree<T>::height_of_tree(const binary_tree* node) const{if (!node)return 0;elsereturn std::max(height_of_tree(node->left), height_of_tree(node->right)) + 1;}template<typename T>int DecisionTree<T>::load_model(const char* name){std::ifstream file(name, std::ios::in);if (!file.is_open()) {fprintf(stderr, "open file fail: %s\n", name);return -1;}std::string line, cell;std::string line, cell;std::getline(file, line);std::stringstream line_stream(line);std::vector<int> vec;int count = 0;while (std::getline(line_stream, cell, ',')) {vec.emplace_back(std::stoi(cell));}CHECK(vec.size() == 2);max_depth = vec[0];min_size = vec[1];max_nodes = (1 << max_depth) - 1;std::vector<row_element> rows(max_nodes);if (typeid(float).name() == typeid(T).name()) {while (std::getline(file, line)) {std::stringstream line_stream2(line);std::vector<T> vec2;while(std::getline(line_stream2, cell, ',')) {vec2.emplace_back(std::stof(cell));}CHECK(vec2.size() == 5);rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec2[4]);//fprintf(stderr, "%d, %d, %f, %f, %f\n", std::get<0>(rows[count]), std::get<1>(rows[count]), std::get<2>(rows[count]), std::get<3>(rows[count]), std::get<4>(rows[c ++count;}} else { // doublewhile (std::getline(file, line)) {std::stringstream line_stream2(line);std::vector<T> vec2;while(std::getline(line_stream2, cell, ',')) {vec2.emplace_back(std::stod(cell));}CHECK(vec2.size() == 5);rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec[4]);++count;}}CHECK(max_nodes == count);CHECK(std::get<0>(rows[0]) != -1);binary_tree* tmp = new binary_tree;std::vector<std::vector<std::vector<T>>> dump;tmp->dict = std::make_tuple(std::get<1>(rows[0]), std::get<2>(rows[0]), dump);tmp->class_value_left = std::get<3>(rows[0]);tmp->class_value_right = std::get<4>(rows[0]);tree = tmp;row_element_to_node(tmp, rows, max_nodes, 0);file.close();return 0;}template<typename T>void DecisionTree<T>::row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos){if (!node || n == 0) return;int new_pos = 2 * pos + 1;if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {node->left = new binary_tree;node->left = new binary_tree;std::vector<std::vector<std::vector<T>>> dump;node->left->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump); node->left->class_value_left = std::get<3>(rows[new_pos]);node->left->class_value_right = std::get<4>(rows[new_pos]);row_element_to_node(node->left, rows, n, new_pos);}new_pos = 2 * pos + 2;if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {node->right = new binary_tree;std::vector<std::vector<std::vector<T>>> dump;node->right->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump); node->right->class_value_left = std::get<3>(rows[new_pos]);node->right->class_value_right = std::get<4>(rows[new_pos]);row_element_to_node(node->right, rows, n, new_pos);}}template<typename T>void DecisionTree<T>::delete_tree(){delete_node(tree);}template<typename T>void DecisionTree<T>::delete_node(binary_tree* node){if (node->left) delete_node(node->left);if (node->right) delete_node(node->right);delete node;}template<typename T>double DecisionTree<T>::accuracy_metric() const{int correct = 0;for (int i = 0; i < this->samples_num; ++i) {T predicted = predict(tree, src_data[i]);if (predicted == src_data[i][this->feature_length])++correct;}double accuracy = correct / (double)samples_num * 100.;fprintf(stdout, "train accuracy: %f\n", accuracy);return accuracy;}template<typename T>void DecisionTree<T>::print_tree(const binary_tree* node, int depth) const{if (node) {std::string blank = " ";for (int i = 0; i < depth; ++i) blank += blank;fprintf(stdout, "%s[X%d < %.3f]\n", blank.c_str(), std::get<0>(node->dict)+1, std::get<1>(node->dict));if (!node->left || !node->right)blank += blank;if (!node->left)fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_left);elseprint_tree(node->left, depth+1);print_tree(node->left, depth+1);if (!node->right)fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_right);elseprint_tree(node->right, depth+1);}}template class DecisionTree<float>;template class DecisionTree<double>;} // namespace ANN对外提供两个接⼝,⼀个是test_decision_tree_train⽤于训练,⼀个是test_decision_tree_predict⽤于测试,其code如下:// =============================== decision tree ==============================int test_decision_tree_train(){// small dataset test/*const std::vector<std::vector<float>> data{ { 2.771244718f, 1.784783929f, 0.f },{ 1.728571309f, 1.169761413f, 0.f },{ 3.678319846f, 2.81281357f, 0.f },{ 3.961043357f, 2.61995032f, 0.f },{ 2.999208922f, 2.209014212f, 0.f },{ 7.497545867f, 3.162953546f, 1.f },{ 9.00220326f, 3.339047188f, 1.f },{ 7.444542326f, 0.476683375f, 1.f },{ 10.12493903f, 3.234550982f, 1.f },{ 6.642287351f, 3.319983761f, 1.f } };const std::vector<float> classes{ 0.f, 1.f };ANN::DecisionTree<float> dt;dt.init(data, classes);dt.set_max_depth(3);dt.set_min_size(1);dt.train();#ifdef _MSC_VERconst char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";#elseconst char* model_name = "data/decision_tree.model";#endifdt.save_model(model_name);ANN::DecisionTree<float> dt2;dt2.load_model(model_name);const std::vector<std::vector<float>> test{{0.6f, 1.9f, 0.f}, {9.7f, 4.3f, 1.f}};for (const auto& row : test) {float ret = dt2.predict(row);fprintf(stdout, "predict result: %.1f, actural value: %.1f\n", ret, row[2]);} */// banknote authentication dataset#ifdef _MSC_VERconst char* file_name = "E:/GitCode/NN_Test/data/database/BacknoteDataset/data_banknote_authentication.txt";#elseconst char* file_name = "data/database/BacknoteDataset/data_banknote_authentication.txt";#endifstd::vector<std::vector<float>> data;int ret = read_txt_file<float>(file_name, data, ',', 1372, 5);if (ret != 0) {fprintf(stderr, "parse txt file fail: %s\n", file_name);return -1;}//fprintf(stdout, "data size: rows: %d\n", data.size());const std::vector<float> classes{ 0.f, 1.f };ANN::DecisionTree<float> dt;dt.init(data, classes);dt.set_max_depth(6);dt.set_min_size(10);dt.train();#ifdef _MSC_VERconst char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model"; #elseconst char* model_name = "data/decision_tree.model";#endifdt.save_model(model_name);return 0;}int test_decision_tree_predict(){#ifdef _MSC_VERconst char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model"; #elseconst char* model_name = "data/decision_tree.model";#endifANN::DecisionTree<float> dt;dt.load_model(model_name);int max_depth = dt.get_max_depth();int min_size = dt.get_min_size();fprintf(stdout, "max_depth: %d, min_size: %d\n", max_depth, min_size);std::vector<std::vector<float>> test {{-2.5526,-7.3625,6.9255,-0.66811,1}, {-4.5531,-12.5854,15.4417,-1.4983,1},{4.0948,-2.9674,2.3689,0.75429,0},{-1.0401,9.3987,0.85998,-5.3336,0},{1.0637,3.6957,-4.1594,-1.9379,1}};for (const auto& row : test) {float ret = dt.predict(row);fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[4]);}return 0;}训练接⼝执⾏结果如下:测试接⼝执⾏结果如下:训练时⽣成的模型decison_tree.model内容如下:6,100,0,0.3223,-1,-1 0,1,7.6274,-1,-1 0,2,-4.3839,-1,-1 0,0,-0.39816,-1,-1 0,0,-4.2859,-1,-1 0,0,4.2164,-1,0 0,0,1.594,-1,-1 0,2,6.2204,-1,-1 0,1,5.8974,-1,-1 0,0,-5.4901,-1,1 0,0,-1.5768,-1,-1 0,0,0.47368,1,-1 -1,-1,-1,-1,-10,2,-2.2718,-1,-1 0,0,2.0421,-1,-1 0,1,7.3273,-1,1 0,1,-4.6062,-1,-1 0,2,3.1143,-1,-1 0,0,0.049175,0,0 0,0,-6.2003,1,1-1,-1,-1,-1,-10,0,-2.7419,0,-1 0,0,-1.5768,0,0-1,-1,-1,-1,-10,0,0.47368,1,1 -1,-1,-1,-1,-1-1,-1,-1,-1,-10,1,7.6377,-1,0 0,3,0.097399,-1,-1 0,2,-2.3386,1,-1 0,0,3.6216,-1,-1 0,0,-1.3971,1,1-1,-1,-1,-1,-10,0,-1.6677,1,1 0,0,-1.7781,0,0 0,0,-0.36506,1,1 0,3,1.547,0,1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-10,0,-2.7419,0,0-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-1-1,-1,-1,-1,-10,0,1.0552,1,1-1,-1,-1,-1,-10,0,0.4339,0,0 0,2,2.0013,1,0-1,-1,-1,-1,-10,0,1.8993,0,0 0,0,3.4566,0,0 0,0,3.6216,0,0。
cart回归树算法例子
Cart回归树是一种基于决策树的回归算法,其主要目的是通过训练数据来构建一棵决策树,以预测新数据的响应变量。
下面是一个Cart回归树算法的例子:
假设我们有一组训练数据包括房屋面积和售价的数据,我们想通过这些数据来构建一个回归树,以预测新房屋的售价。
首先,我们需要将训练数据按照房屋面积进行排序,然后选择一个分裂点,将数据集分为两个子集,使得每个子集内的数据尽量相似,同时让分裂点左边的数据集的响应变量的方差尽可能小,右边的数据集的响应变量的方差尽可能小。
我们可以通过计算划分前后的均方误差(MSE)来选择最优的分裂点。
接下来,我们将左右两个子集按照同样的方法进行分裂,直到达到某种停止条件(例如,达到了最大树深度或子集大小小于某个阈值)。
最终得到一棵回归树,可以用来预测新房屋的售价。
在预测新数据时,我们从根节点开始遍历树,按照每个节点的分裂规则选择左右子树,直到到达叶节点。
叶节点存储了一个预测值,即该节点下的所有训练数据的响应变量的平均值。
预测值就是遍历过程中所有叶节点的预测值的加权平均值。
这是一个简单的Cart回归树算法的例子,该算法可以应用于许多不同的回归问题。
- 1 -。
决策树算法(CART分类树) 在中,提到C4.5的不⾜,⽐如模型是⽤较为复杂的熵来度量,使⽤了相对较为复杂的多叉树,只能处理分类不能处理回归。
对这些问题,CART(Classification And Regression Tree)做了改进,可以处理分类,也可以处理回归。
1. CART分类树算法的最优特征选择⽅法 ID3中使⽤了信息增益选择特征,增益⼤优先选择。
C4.5中,采⽤信息增益⽐选择特征,减少因特征值多导致信息增益⼤的问题。
CART分类树算法使⽤基尼系数来代替信息增益⽐,基尼系数代表了模型的不纯度,基尼系数越⼩,不纯度越低,特征越好。
这和信息增益(⽐)相反。
假设K个类别,第k个类别的概率为p k,概率分布的基尼系数表达式: 如果是⼆分类问题,第⼀个样本输出概率为p,概率分布的基尼系数表达式为: 对于样本D,个数为|D|,假设K个类别,第k个类别的数量为|C k|,则样本D的基尼系数表达式: 对于样本D,个数为|D|,根据特征A的某个值a,把D分成|D1|和|D2|,则在特征A的条件下,样本D的基尼系数表达式为: ⽐较基尼系数和熵模型的表达式,⼆次运算⽐对数简单很多。
尤其是⼆分类问题,更加简单。
和熵模型的度量⽅式⽐,基尼系数对应的误差有多⼤呢?对于⼆类分类,基尼系数和熵之半的曲线如下: 基尼系数和熵之半的曲线⾮常接近,因此,基尼系数可以做为熵模型的⼀个近似替代。
CART分类树算法每次仅对某个特征的值进⾏⼆分,⽽不是多分,这样CART分类树算法建⽴起来的是⼆叉树,⽽不是多叉树。
2. CART分类树算法具体流程 CART分类树建⽴算法流程,之所以加上建⽴,是因为CART分类树算法有剪枝算法流程。
算法输⼊训练集D,基尼系数的阈值,样本个数阈值。
输出的是决策树T。
算法从根节点开始,⽤训练集递归建⽴CART分类树。
(1)、对于当前节点的数据集为D,如果样本个数⼩于阈值或没有特征,则返回决策⼦树,当前节点停⽌递归。
cart决策树例题简单案例决策树是一种常用的机器学习算法,可以用于分类和回归问题。
它通过对特征进行划分来建立一个树状的决策流程,从而对新的样本进行预测或分类。
在本文中,我们将通过一个简单的案例来介绍决策树的基本原理和应用。
假设我们有一个购物车数据集,其中包含了一些特征和对应的标签。
我们的目标是根据这些特征来预测一个购物车是否会购买商品。
首先,我们需要加载数据集并进行数据预处理。
数据预处理的目的是将原始数据转换为适用于决策树算法的格式。
我们可以使用Python的pandas库来完成这些任务。
```pythonimport pandas as pd# 加载数据集data = pd.read_csv('shopping_cart.csv')# 数据预处理# ...```接下来,我们需要选择用于构建决策树的特征。
在这个例子中,我们假设特征包括购买的商品种类、购买的商品数量以及购物车的总价。
根据经验,我们可以选择购买的商品种类和购物车的总价作为特征,因为它们可能与购买行为更相关。
然后,我们将数据集分为训练集和测试集。
训练集用于构建决策树模型,而测试集用于评估模型的性能。
```pythonfrom sklearn.model_selection import train_test_split# 选择特征和标签X = data[['商品种类', '购物车总价']]y = data['购买']# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,random_state=42)```接下来,我们可以使用scikit-learn库中的DecisionTreeClassifier类来构建决策树模型。
```pythonfrom sklearn.tree import DecisionTreeClassifier# 构建决策树模型model = DecisionTreeClassifier()# 在训练集上训练模型model.fit(X_train, y_train)```训练完成后,我们可以使用训练好的模型对测试集进行预测,并评估模型的性能。
数据挖掘实践(50):决策树计算过程实例(四)CART树算法分类来源:https:///e15273/article/details/79648502⼀算法步骤CART假设决策树是⼆叉树,内部结点特征的取值为“是”和“否”,左分⽀是取值为“是”的分⽀,右分⽀是取值为“否”的分⽀。
这样的决策树等价于递归地⼆分每个特征,将输⼊空间即特征空间划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输⼊给定的条件下输出的条件概率分布。
CART算法由以下两步组成:1. 决策树⽣成:基于训练数据集⽣成决策树,⽣成的决策树要尽量⼤;决策树剪枝:⽤验证数据集对已⽣成的树进⾏剪枝并选择最优⼦树,这时损失函数最⼩作为剪枝的标准。
2. CART决策树的⽣成就是递归地构建⼆叉决策树的过程。
CART决策树既可以⽤于分类也可以⽤于回归。
本⽂我们仅讨论⽤于分类的CART。
对分类树⽽⾔,CART⽤Gini系数最⼩化准则来进⾏特征选择,⽣成⼆叉树。
CART⽣成算法如下:输⼊:训练数据集D,停⽌计算的条件:输出:CART决策树。
根据训练数据集,从根结点开始,递归地对每个结点进⾏以下操作,构建⼆叉决策树:设结点的训练数据集为D,计算现有特征对该数据集的Gini系数。
此时,对每⼀个特征A,对其可能取的每个值a,根据样本点对A=a的测试为“是”或 “否”将D分割成D1和D2两部分,计算A=a时的Gini系数。
在所有可能的特征A以及它们所有可能的切分点a中,选择Gini系数最⼩的特征及其对应的切分点作为最优特征与最优切分点。
依最优特征与最优切分点,从现结点⽣成两个⼦结点,将训练数据集依特征分配到两个⼦结点中去。
对两个⼦结点递归地调⽤步骤l~2,直⾄满⾜停⽌条件。
⽣成CART决策树。
算法停⽌计算的条件是结点中的样本个数⼩于预定阈值,或样本集的Gini系数⼩于预定阈值(样本基本属于同⼀类),或者没有更多特征。
⼆ Gini指数的计算其实gini指数最早应⽤在经济学中,主要⽤来衡量收⼊分配公平度的指标。
cart决策树例题简单案例
决策树是一种常见的机器学习算法,用于分类和预测分析。
它
通过一系列规则和条件来对数据进行分类或预测,类似于真实世界
中的决策过程。
下面我将给你一个简单的购物车决策树的例子。
假设我们有一个购物车决策树,用于预测一个顾客是否会购买
某种产品。
我们收集了一些顾客的数据,包括年龄、性别、收入和
是否有小孩。
我们想要通过这些数据来预测顾客是否会购买某种产品。
首先,我们可以使用年龄作为第一个分裂节点。
如果顾客年龄
小于30岁,则我们进一步考虑性别;如果是女性,则我们再考虑收入;如果收入高于某个阈值,则预测她会购买;如果收入低于阈值,则再考虑是否有小孩,如果有小孩则预测她会购买。
如果是男性,
则我们可能会根据其他特征进行进一步的分裂。
这只是一个简单的例子,实际上,决策树可以根据具体情况进
行更复杂的分裂和预测。
在实际应用中,决策树可以用于各种领域,如金融、医疗和市场营销等,用来预测客户购买行为、疾病风险等。
总的来说,决策树是一种直观且易于理解的机器学习算法,它可以帮助我们从数据中发现规律,并做出有效的预测和决策。
希望这个简单的例子可以帮助你更好地理解决策树的应用和工作原理。
cart决策树案例
CART(Classification and Regression Trees)决策树是一种常用的机器学习算法,既可以用于分类问题,也可以用于回归问题。
下面是一个使用CART决策树解决分类问题的案例:
案例背景:一家电商网站想要预测用户是否会购买某商品,以便更好地进行商品推荐。
为此,他们收集了一些用户数据,包括用户的年龄、性别、购买历史、浏览历史等。
数据准备:首先,对数据进行预处理,包括缺失值处理、异常值处理、数据规范化等。
例如,对于年龄这一特征,可以将数据规范化到0-1之间。
特征选择:根据业务需求和数据特点,选择合适的特征进行建模。
例如,在本案例中,可以选择年龄、性别、购买历史、浏览历史等特征进行建模。
模型训练:使用CART决策树算法对数据进行训练,生成预测模型。
在本案例中,目标变量是用户是否购买某商品,因此这是一个二分类问题。
模型评估:使用测试集对模型进行评估,计算模型的准确率、精确率、召回率等指标。
如果模型表现不佳,需要对模型进行调整和优化。
应用场景:生成的模型可以应用于实际的电商推荐系统中,根据用户的历史数据和浏览行为等信息,预测用户是否会购买某商品,并据此进行商品推荐。
这只是一个简单的CART决策树分类案例,实际应用中可能还需要考虑更多的因素和细节。