11#include " specialized_solver.h"
2+ #include " general_solver.h"
23
34class Depth1ScoreHelper {
45public:
@@ -43,19 +44,19 @@ class Depth1ScoreHelper {
4344 std::vector<int > current_label_frequency;
4445};
4546
46- void SpecializedSolver::get_best_left_right_scores (const Dataview& dataview, int feature_index, int split_point, float threshold, std::shared_ptr<Tree> &left_optimal_dt, std::shared_ptr<Tree> &right_optimal_dt , float upper_bound, float complexity_cost) {
47+ void SpecializedSolver::get_best_left_right_scores (const Dataview& dataview, SplitInfo& split , float upper_bound, float complexity_cost) {
4748 RUNTIME_ASSERT (dataview.get_dataset_size () > 0 , " Dataset cannot be empty." );
4849 RUNTIME_ASSERT (upper_bound >= complexity_cost, " Upper bound should always be complexity-cost or higher." );
49- RUNTIME_ASSERT (left_optimal_dt->objective >= 0 , " Current objective should always be zero or higher." );
50- RUNTIME_ASSERT (right_optimal_dt->objective >= 0 , " Current objective should always be zero or higher." );
50+ RUNTIME_ASSERT (split. left_optimal_dt ->objective >= 0 , " Current objective should always be zero or higher." );
51+ RUNTIME_ASSERT (split. right_optimal_dt ->objective >= 0 , " Current objective should always be zero or higher." );
5152
52- const auto & split_feature = dataview.get_sorted_dataset_feature (feature_index );
53- const auto & unsorted_split_feature = dataview.get_unsorted_dataset_feature (feature_index );
53+ const auto & split_feature = dataview.get_sorted_dataset_feature (split. feature );
54+ const auto & unsorted_split_feature = dataview.get_unsorted_dataset_feature (split. feature );
5455 std::vector<int > split_feature_split_indices (unsorted_split_feature.size ());
5556 int split_index = -1 ;
5657 for (const auto & split_feature_data : split_feature) {
5758 split_feature_split_indices[split_feature_data.data_point_index ] = split_feature_data.unique_value_index ;
58- if (split_index == -1 && split_feature_data.value >= threshold ) {
59+ if (split_index == -1 && split_feature_data.unique_value_index >= split. unique_value_threshold ) {
5960 split_index = split_feature_data.unique_value_index ;
6061 }
6162 }
@@ -64,15 +65,15 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
6465 const int dataset_size = dataview.get_dataset_size ();
6566 const int class_number = dataview.get_class_number ();
6667
67- RUNTIME_ASSERT (split_point > 0 && split_point < dataset_size, " left and right subtree need to be non-empty." );
68- Depth1ScoreHelper left_tree (split_point, class_number);
69- Depth1ScoreHelper right_tree (dataset_size - split_point, class_number);
68+ RUNTIME_ASSERT (split. split_point > 0 && split. split_point < dataset_size, " left and right subtree need to be non-empty." );
69+ Depth1ScoreHelper left_tree (split. split_point , class_number);
70+ Depth1ScoreHelper right_tree (dataset_size - split. split_point , class_number);
7071
7172
7273 left_tree.classification_score = std::max (0 .0f , float (left_tree.size ) - upper_bound);
7374 right_tree.classification_score = std::max (0 .0f , float (right_tree.size ) - upper_bound);
7475
75- Dataview::initialize_split_parameters (split_feature, class_number, dataview.get_label_frequency (), split_point, left_tree.label_frequency , right_tree.label_frequency );
76+ Dataview::initialize_split_parameters (split_feature, class_number, dataview.get_label_frequency (), split. split_point , left_tree.label_frequency , right_tree.label_frequency );
7677
7778 left_tree.max_label_frequency = 0 ;
7879 right_tree.max_label_frequency = 0 ;
@@ -91,38 +92,38 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
9192
9293 for (int current_feature_index = 0 ; current_feature_index < dataview.get_feature_number (); current_feature_index++) {
9394 if (left_tree.classification_score + right_tree.classification_score == dataset_size) break ;
94- if (current_feature_index == feature_index ) {
95- process_depth_one_feature<true >(dataview, feature_index, split_point, current_feature_index, split_index,
95+ if (current_feature_index == split. feature ) {
96+ process_depth_one_feature<true >(dataview, split. feature , split. split_point , current_feature_index, split_index,
9697 left_tree, right_tree, split_feature_split_indices, upper_bound, complexity_cost);
9798 } else {
98- process_depth_one_feature<false >(dataview, feature_index, split_point, current_feature_index, split_index,
99+ process_depth_one_feature<false >(dataview, split. feature , split. split_point , current_feature_index, split_index,
99100 left_tree, right_tree, split_feature_split_indices, upper_bound, complexity_cost);
100101 }
101102 }
102103
103104 if (left_tree.is_leaf ()) {
104- left_optimal_dt->make_leaf (left_tree.max_label , left_tree.size - left_tree.classification_score );
105+ split. left_optimal_dt ->make_leaf (left_tree.max_label , left_tree.size - left_tree.classification_score );
105106 } else {
106- left_optimal_dt->update_split (left_tree.best_feature_index , left_tree.best_threshold ,
107+ split. left_optimal_dt ->update_split (left_tree.best_feature_index , left_tree.best_threshold ,
107108 std::make_shared<Tree>(left_tree.best_left_label , -1 ),
108109 std::make_shared<Tree>(left_tree.best_right_label , -1 ), complexity_cost);
109- left_optimal_dt->objective = left_tree.get_objective (complexity_cost);
110+ split. left_optimal_dt ->objective = left_tree.get_objective (complexity_cost);
110111 // RUNTIME_ASSERT(left_tree.best_left_label != -1, "Left tree left label should be initialized.");
111112 // RUNTIME_ASSERT(left_tree.best_right_label != -1, "Left tree right label should be initialized.");
112113 }
113- RUNTIME_ASSERT (left_optimal_dt->objective >= 0 , " LR - Left tree misclassification score should be non-negative." );
114+ RUNTIME_ASSERT (split. left_optimal_dt ->objective >= 0 , " LR - Left tree misclassification score should be non-negative." );
114115
115116 if (right_tree.is_leaf ()) {
116- right_optimal_dt->make_leaf (right_tree.max_label , right_tree.size - right_tree.classification_score );
117+ split. right_optimal_dt ->make_leaf (right_tree.max_label , right_tree.size - right_tree.classification_score );
117118 } else {
118- right_optimal_dt->update_split (right_tree.best_feature_index , right_tree.best_threshold ,
119+ split. right_optimal_dt ->update_split (right_tree.best_feature_index , right_tree.best_threshold ,
119120 std::make_shared<Tree>(right_tree.best_left_label , -1 ),
120121 std::make_shared<Tree>(right_tree.best_right_label , -1 ), complexity_cost);
121- right_optimal_dt->objective = right_tree.get_objective (complexity_cost);
122+ split. right_optimal_dt ->objective = right_tree.get_objective (complexity_cost);
122123 // RUNTIME_ASSERT(right_tree.best_left_label != -1, "Right tree left label should be initialized.");
123124 // RUNTIME_ASSERT(right_tree.best_right_label != -1, "Right tree right label should be initialized.");
124125 }
125- RUNTIME_ASSERT (right_optimal_dt->objective >= 0 , " LR - Right tree misclassification score should be non-negative." );
126+ RUNTIME_ASSERT (split. right_optimal_dt ->objective >= 0 , " LR - Right tree misclassification score should be non-negative." );
126127}
127128
128129template <bool is_same_feature>
@@ -159,11 +160,13 @@ void SpecializedSolver::process_depth_one_feature(const Dataview& dataview,
159160
160161 tree.can_skip --;
161162
162- if (current_feature_data.unique_value_index == tree.previous_unique_value_index || tree.can_skip > 0 ){
163+ if (current_feature_data.unique_value_index == tree.previous_unique_value_index || tree.can_skip > 0 ) {
163164 tree.current_element_count ++;
164165 tree.current_label_frequency [current_feature_data.label ]++;
165166 tree.previous_value = current_feature_data.value ;
166167 tree.previous_unique_value_index = current_feature_data.unique_value_index ;
168+ RUNTIME_ASSERT (tree.current_label_frequency [current_feature_data.label ] <= tree.label_frequency [current_feature_data.label ]
169+ , " Current label frequency can never exceed the maximum label frequency." );
167170 continue ;
168171 }
169172
0 commit comments