#pragma once #include #include #include #include #include #include #include #include "question.h" using namespace std; using namespace Eigen; struct Node_reg { Question Q; Node_reg* left = nullptr; Node_reg* right = nullptr; double labels; ~Node_reg() { delete left; delete right; } }; class RegressionTree { private: Node_reg* root= nullptr; public: RegressionTree(); ~RegressionTree(); void fit(const MatrixXd& X, const VectorXd& Y); VectorXd predict(const MatrixXd& X); double test_pred(const MatrixXd& X, const VectorXd& Y); void printTree(); void save(ofstream &ff); void rebuild(ifstream &ff,int classes); }; namespace rt { vector init_split(const VectorXd& Y); vector init_cols(int size); void build_tree(Node_reg*& node, const MatrixXd& X, const VectorXd& Y, const vector& split, const vector& cols); Question find_best_question(const MatrixXd& X, const VectorXd& Y, const vector& split, const vector& cols, double& best_gain); double mse(const VectorXd& Y, const vector& split); double cal_avg(const VectorXd& Y, const vector& split); int n_classes(const VectorXd& Y); vector unique_values(const VectorXd& col, const vector& split); bool isOverlap(const vector& unique, double value); vector> split_Node_reg(const Question& Q, const MatrixXd& X, const vector& split); double info_gain_reg(const VectorXd& Y, const vector& left, const vector& right, double current); vector erase_taken_col(const Question& Q, const vector& cols); double predict_implementation_reg(const RowVectorXd& x, Node_reg* node); void print_implementation(Node_reg* node, int64_t wirth); void rebuild_implement(std::ifstream & ff, Node_reg*& node,int class_num); void save_implement(std::ofstream & ff ,Node_reg*& node); } RegressionTree::RegressionTree() : root(nullptr) {} RegressionTree::~RegressionTree() { //delete root; } void RegressionTree::fit(const MatrixXd& X, const VectorXd& Y) { rt::build_tree(root, X, Y, rt::init_split(Y), rt::init_cols((int)X.cols())); } vector rt::init_split(const VectorXd& Y) { vector split; for (int i = 0; i < Y.size(); i++) split.push_back(i); return split; } vector rt::init_cols(int size) { vector cols; for (int i = 0; i < size; i++) cols.push_back(i); return cols; } void rt::build_tree(Node_reg*& node, const MatrixXd& X, const VectorXd& Y, const vector& split, const vector& cols) { double gain; Question Q = find_best_question(X, Y, split, cols, gain); node = new Node_reg; node->labels = cal_avg(Y, split); if (gain >= 0.08) { node->Q = Q; vector> splits = split_Node_reg(Q, X, split); vector new_cols = erase_taken_col(Q, cols); build_tree(node->left, X, Y, splits[0], new_cols); build_tree(node->right, X, Y, splits[1], new_cols); } } Question rt::find_best_question(const MatrixXd& X, const VectorXd& Y, const vector& split, const vector& cols, double& best_gain) { Question best_Q; best_gain = 0; double current_uncertainty = mse(Y, split); for (int idx : cols) { vector unique = unique_values(X.col(idx), split); for (double value : unique) { Question Q(idx, value); vector> splits = split_Node_reg(Q, X, split); if (splits[0].size() == 0 || splits[1].size() == 0) continue; double gain = info_gain_reg(Y, splits[0], splits[1], current_uncertainty); if (gain >= best_gain) { best_gain = gain; best_Q = Q; } } } return best_Q; } double rt::mse(const VectorXd& Y, const vector& split) { double avg = cal_avg(Y,split); double sum = 0; for (int idx : split) { double truth = (double)Y[idx]; sum += std::pow(truth-avg,2); } return sum; } double rt::cal_avg(const VectorXd& Y, const vector& split) { double sum = 0; for (int idx : split) { double truth = (double)Y[idx]; sum+=truth; } double avg = sum/split.size(); return avg; } int rt::n_classes(const VectorXd& Y) { return int(*std::max_element(Y.data(), Y.data() + Y.size()) + 1); } vector rt::unique_values(const VectorXd& col, const vector& split) { vector unique; for (auto idx : split) if (!isOverlap(unique, col[idx])) unique.push_back(col[idx]); return unique; } bool rt::isOverlap(const vector& unique, double value) { auto iter = std::find(unique.begin(), unique.end(), value); return iter != unique.end(); } vector> rt::split_Node_reg(const Question& Q, const MatrixXd& X, const vector& split) { vector> splits(2, vector()); for (int idx : split) { if (Q.match(X.row(idx))) splits[0].push_back(idx); else splits[1].push_back(idx); } return splits; } double rt::info_gain_reg(const VectorXd& Y, const vector& left, const vector& right, double current) { return current - (mse(Y,left)+mse(Y,right)); } vector rt::erase_taken_col(const Question& Q, const vector& cols) { vector new_cols; for (int idx : cols) if (idx != Q.getCol()) new_cols.push_back(idx); return new_cols; } VectorXd RegressionTree::predict(const MatrixXd& X) { VectorXd labels(X.rows()); for (int i = 0; i < X.rows(); i++) labels[i] = rt::predict_implementation_reg(X.row(i), root); return labels; } double RegressionTree::test_pred(const MatrixXd& X, const VectorXd& Y){ VectorXd pred(X.rows()); pred = predict(X); double total=0; for (int i = 0; i < X.rows(); i++){ total+=pow(pred[i]-Y[i],2); } return (double)total/X.rows(); } double rt::predict_implementation_reg(const RowVectorXd& x, Node_reg* node) { while (node->left != nullptr && node->right != nullptr) { if (node->Q.match(x)) node = node->left; else node = node->right; } auto avg = node->labels; return avg; } void RegressionTree::printTree() { rt::print_implementation(root, 0); } void RegressionTree::rebuild(std::ifstream & ff,int class_num){ rt::rebuild_implement(ff,root,class_num); printTree(); } void rt::rebuild_implement(std::ifstream & ff, Node_reg*& node,int class_num){ if(ff.eof()){ return ; } char c; ff>>c; node = new Node_reg; if (c=='Q'){ int tmpcol; double tmpval; ff>>tmpcol>>tmpval; Question tmpq(tmpcol,tmpval); node->Q=tmpq; rt::rebuild_implement(ff,node->left,class_num); rt::rebuild_implement(ff,node->right,class_num); } if(c=='L'){ double val; ff>>val; node->labels = val; return; } } void RegressionTree::save(std::ofstream & ff ){ rt::save_implement(ff,root); } void rt::save_implement(std::ofstream & ff ,Node_reg*& node){ // Q attribution value / # / L label if (node->left == nullptr && node->right == nullptr){ auto max = node->labels; ff<<'L'<<" "<Q.getCol()<<" "<Q.getValue()<left); rt::save_implement(ff,node->right); } } void rt::print_implementation(Node_reg* node, int64_t wirth) { if (node->left == nullptr && node->right == nullptr) { cout << setw(wirth + 4) << " " << "Predict : {"; cout << node->labels; cout << "}" << endl; } else { cout << setw(wirth) << " " << "Q : X" << node->Q.getCol() + 1 << " >= " << node->Q.getValue() << " ? " << endl; cout << setw(wirth) << " " << "--> True: " << endl; print_implementation(node->left, wirth + 4); cout << setw(wirth) << " " << "--> False: " << endl; print_implementation(node->right, wirth + 4); } }