Skip to content

Commit 2a62880

Browse files
committed
Bug fix: replace static statistics counters with per-Configuration (dynamic) statistics
1 parent 0fe6b4d commit 2a62880

14 files changed

Lines changed: 71 additions & 79 deletions

code/Engine/include/general_solver.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "cache.h"
1010
#include "dataset.h"
1111
#include "dataview.h"
12-
#include "general_solver.h"
1312
#include "intervals_pruner.h"
1413
#include "specialized_solver.h"
1514
#include "statistics.h"
@@ -31,6 +30,7 @@ struct SplitInfo {
3130
int mid = -1; // the mid point in the interval, the split point we are splitting on now
3231
int split_point = -1; // the sample index of the first sample right of the split
3332
float threshold = -1; // the split threshold (in the continuous numbers, not the unique_value_index)
33+
int unique_value_threshold = -1; // the split threshold (in the unique value index, not the continuous value)
3434
std::shared_ptr<Tree> left_optimal_dt;
3535
std::shared_ptr<Tree> right_optimal_dt;
3636
};
@@ -72,7 +72,7 @@ class GeneralSolver {
7272
* @param current_optimal_tree The current optimal tree.
7373
* @param upper_bound The upper bound for the search space.
7474
*/
75-
static void solve_split(const Dataview& dataview, const Configuration& config, SplitInfo& split, std::shared_ptr<Tree>& current_optimal_tree, float upper_bound);
75+
static void solve_split(const Dataview& dataview, Configuration& config, SplitInfo& split, std::shared_ptr<Tree>& current_optimal_tree, float upper_bound);
7676

7777
/**
7878
* Calculates the misclassification score if the current node is a leaf node.

code/Engine/include/specialized_solver.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,25 @@
1010
#include "dataset.h"
1111
#include "dataview.h"
1212
#include "intervals_pruner.h"
13-
#include "specialized_solver.h"
1413
#include "statistics.h"
1514
#include "tree.h"
1615

16+
1717
class Depth1ScoreHelper;
1818

19+
struct SplitInfo;
20+
1921
class SpecializedSolver {
2022
public:
2123
/**
2224
* Calculates the misclassification scores for both the left and right splits of the dataset using only one dataset traversal.
2325
*
2426
* @param dataset The dataset to calculate the scores from.
25-
* @param feature_index The index of the feature to split on.
26-
* @param split_point The split point for the feature.
27-
* @param threshold The threshold value for the split.
28-
* @param left_optimal_tree The optimal tree for the left split.
29-
* @param right_optimal_tree The optimal tree for the right split.
27+
* @param split Information on the root node split, such as what feature to split on
3028
* @param upper_bound The upper bound for the scores.
3129
* @param complexity_cost The cost of adding a node
3230
*/
33-
static void get_best_left_right_scores(const Dataview& dataset, int feature_index, int split_point, float threshold,
34-
std::shared_ptr<Tree>& left_optimal_tree, std::shared_ptr<Tree>& right_optimal_tree, float upper_bound, float complexity_cost);
31+
static void get_best_left_right_scores(const Dataview& dataset, SplitInfo& split, float upper_bound, float complexity_cost);
3532

3633
private:
3734
/**

code/Engine/src/general_solver.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, Confi
88
// Check if we have the subproblem cached already
99
if (Cache::global_cache.is_cached(dataview, config.max_depth)) {
1010
current_optimal_decision_tree = Cache::global_cache.retrieve(dataview, config.max_depth);
11-
statistics::total_number_cache_hits += 1;
11+
config.stats->total_number_cache_hits += 1;
1212
return;
1313
}
1414

@@ -57,6 +57,7 @@ SplitInfo initialize_split(const std::vector<int>& possible_split_indices, const
5757
split.split_point = possible_split_indices[(left + right) / 2];
5858
split.threshold = split.mid > 0 ? (current_feature[possible_split_indices[split.mid - 1]].value + current_feature[split.split_point].value) / 2.0f
5959
: (current_feature[split.split_point].value + current_feature[0].value) / 2.0f;
60+
split.unique_value_threshold = current_feature[split.split_point].unique_value_index;
6061
split.left_optimal_dt = std::make_shared<Tree>();
6162
split.right_optimal_dt = std::make_shared<Tree>();
6263
return split;
@@ -110,12 +111,12 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, Confi
110111
if (split.left_optimal_dt->is_initialized() && split.right_optimal_dt->is_initialized() && current_best_score < current_optimal_decision_tree->objective) {
111112
current_optimal_decision_tree->update_split(feature_index, split.threshold, split.left_optimal_dt, split.right_optimal_dt, config.complexity_cost);
112113
upper_bound = std::min(upper_bound, current_best_score);
113-
if (current_best_score <= config.complexity_cost + EPSILON) return;
114114
if (PRINT_INTERMEDIARY_TIME_SOLUTIONS && config.is_root) {
115-
const auto stop = std::chrono::high_resolution_clock::now();
116-
const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - starting_time);
117-
std::cout << "Time taken to get the misclassification score " << current_best_score << ": " << duration.count() / 1000.0 << " seconds" << std::endl;
115+
std::cout << "Time taken to get the misclassification score " << current_best_score
116+
<< ": " << std::fixed << std::setprecision(4) << config.stopwatch.TimeElapsedInSeconds() << " seconds ("
117+
<< (config.stats->total_number_of_general_solver_calls + config.stats->total_number_of_specialized_solver_calls) <<" expansions)" << std::endl;
118118
}
119+
if (current_best_score <= config.complexity_cost + EPSILON) return;
119120
}
120121
// Add the result to the interval pruner
121122
interval_pruner.add_result(split.mid, split.left_optimal_dt->objective, split.right_optimal_dt->objective);
@@ -141,26 +142,24 @@ void GeneralSolver::create_optimal_decision_tree(const Dataview& dataview, Confi
141142
}
142143
}
143144

144-
void GeneralSolver::solve_split(const Dataview& dataview, const Configuration& config, SplitInfo& split, std::shared_ptr<Tree>& current_optimal_tree, float upper_bound) {
145+
void GeneralSolver::solve_split(const Dataview& dataview, Configuration& config, SplitInfo& split, std::shared_ptr<Tree>& current_optimal_tree, float upper_bound) {
145146
const std::vector<Dataset::FeatureElement>& current_feature = dataview.get_sorted_dataset_feature(split.feature);
146147
const auto& possible_split_indices = dataview.get_possible_split_indices(split.feature);
147148

148149
if (config.max_depth == 2) {
149150
// If the maximum remaining depth is two, we use our special depth two solver
150-
statistics::total_number_of_specialized_solver_calls += 1;
151-
SpecializedSolver::get_best_left_right_scores(dataview, split.feature, split.split_point, split.threshold, split.left_optimal_dt, split.right_optimal_dt, upper_bound, config.complexity_cost);
151+
config.stats->total_number_of_specialized_solver_calls += 1;
152+
SpecializedSolver::get_best_left_right_scores(dataview, split, upper_bound, config.complexity_cost);
152153
RUNTIME_ASSERT(split.left_optimal_dt->objective >= 0, "D2 - Left tree should have non-negative misclassification score.");
153154
RUNTIME_ASSERT(split.right_optimal_dt->objective >= 0, "D2 - Right tree should have non-negative misclassification score.");
154155
} else {
155156
// We compute the distance from the split point to the sample indices of the left and right interval split points
156157
const int interval_half_distance = std::max(split.split_point - possible_split_indices[split.left], possible_split_indices[split.right] - split.split_point);
157-
// The unique-value index of the split point. We use this to split the data (rather than the continuous threshold, which may give numerical instability)
158-
const int split_unique_value_index = current_feature[split.split_point].unique_value_index;
159158

160159
// Split the dataset in two based on the split point
161160
Dataview left_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
162161
Dataview right_dataview = Dataview(dataview.get_class_number(), dataview.should_sort_by_gini_index());
163-
Dataview::split_data_points(dataview, split.feature, split.split_point, split_unique_value_index, left_dataview, right_dataview, config.max_depth);
162+
Dataview::split_data_points(dataview, split.feature, split.split_point, split.unique_value_threshold, left_dataview, right_dataview, config.max_depth);
164163

165164
// We first compute the subtree for the bigger dataset since it might make computing the smaller dataset obsolete
166165
auto& smaller_data = (left_dataview.get_dataset_size() < right_dataview.get_dataset_size()) ? left_dataview : right_dataview;
@@ -175,7 +174,7 @@ void GeneralSolver::solve_split(const Dataview& dataview, const Configuration& c
175174
: current_optimal_tree->objective;
176175

177176
// Recursively solve the subtree
178-
statistics::total_number_of_general_solver_calls += 1;
177+
config.stats->total_number_of_general_solver_calls += 1;
179178
Configuration left_solution_configuration = config.GetLeftSubtreeConfig();
180179
GeneralSolver::create_optimal_decision_tree(larger_data, left_solution_configuration, larger_optimal_dt, larger_ub);
181180
RUNTIME_ASSERT(larger_optimal_dt->objective >= 0, "Right tree should have non-negative misclassification score.");
@@ -189,7 +188,7 @@ void GeneralSolver::solve_split(const Dataview& dataview, const Configuration& c
189188
// We compute the second subtree only if we have a positive upper bound, larger than zero.
190189
// If the upper bound is precisely zero, but not because of the interval_half distance, we also want to compute the subproblem
191190
if (smaller_ub > 0 || (std::abs(smaller_ub) <= EPSILON && std::abs(smaller_obj_ub) <= EPSILON)) {
192-
statistics::total_number_of_general_solver_calls += 1;
191+
config.stats->total_number_of_general_solver_calls += 1;
193192
Configuration right_solution_configuration = config.GetRightSubtreeConfig(left_solution_configuration.max_gap);
194193
GeneralSolver::create_optimal_decision_tree(smaller_data, right_solution_configuration, smaller_optimal_dt, smaller_ub);
195194
RUNTIME_ASSERT(smaller_optimal_dt->objective >= 0, "Left tree should have non-negative misclassification score.");

code/Engine/src/specialized_solver.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "specialized_solver.h"
2+
#include "general_solver.h"
23

34
class Depth1ScoreHelper {
45
public:
@@ -43,19 +44,19 @@ class Depth1ScoreHelper {
4344
std::vector<int> current_label_frequency;
4445
};
4546

46-
void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int feature_index, int split_point, float threshold, std::shared_ptr<Tree> &left_optimal_dt, std::shared_ptr<Tree> &right_optimal_dt, float upper_bound, float complexity_cost) {
47+
void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, SplitInfo& split, float upper_bound, float complexity_cost) {
4748
RUNTIME_ASSERT(dataview.get_dataset_size() > 0, "Dataset cannot be empty.");
4849
RUNTIME_ASSERT(upper_bound >= complexity_cost, "Upper bound should always be complexity-cost or higher.");
49-
RUNTIME_ASSERT(left_optimal_dt->objective >= 0, "Current objective should always be zero or higher.");
50-
RUNTIME_ASSERT(right_optimal_dt->objective >= 0, "Current objective should always be zero or higher.");
50+
RUNTIME_ASSERT(split.left_optimal_dt->objective >= 0, "Current objective should always be zero or higher.");
51+
RUNTIME_ASSERT(split.right_optimal_dt->objective >= 0, "Current objective should always be zero or higher.");
5152

52-
const auto& split_feature = dataview.get_sorted_dataset_feature(feature_index);
53-
const auto& unsorted_split_feature = dataview.get_unsorted_dataset_feature(feature_index);
53+
const auto& split_feature = dataview.get_sorted_dataset_feature(split.feature);
54+
const auto& unsorted_split_feature = dataview.get_unsorted_dataset_feature(split.feature);
5455
std::vector<int> split_feature_split_indices(unsorted_split_feature.size());
5556
int split_index = -1;
5657
for (const auto& split_feature_data : split_feature) {
5758
split_feature_split_indices[split_feature_data.data_point_index] = split_feature_data.unique_value_index;
58-
if (split_index == -1 && split_feature_data.value >= threshold) {
59+
if (split_index == -1 && split_feature_data.unique_value_index >= split.unique_value_threshold) {
5960
split_index = split_feature_data.unique_value_index;
6061
}
6162
}
@@ -64,15 +65,15 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
6465
const int dataset_size = dataview.get_dataset_size();
6566
const int class_number = dataview.get_class_number();
6667

67-
RUNTIME_ASSERT(split_point > 0 && split_point < dataset_size, "left and right subtree need to be non-empty.");
68-
Depth1ScoreHelper left_tree(split_point, class_number);
69-
Depth1ScoreHelper right_tree(dataset_size - split_point, class_number);
68+
RUNTIME_ASSERT(split.split_point > 0 && split.split_point < dataset_size, "left and right subtree need to be non-empty.");
69+
Depth1ScoreHelper left_tree(split.split_point, class_number);
70+
Depth1ScoreHelper right_tree(dataset_size - split.split_point, class_number);
7071

7172

7273
left_tree.classification_score = std::max(0.0f, float(left_tree.size) - upper_bound);
7374
right_tree.classification_score = std::max(0.0f, float(right_tree.size) - upper_bound);
7475

75-
Dataview::initialize_split_parameters(split_feature, class_number, dataview.get_label_frequency(), split_point, left_tree.label_frequency, right_tree.label_frequency);
76+
Dataview::initialize_split_parameters(split_feature, class_number, dataview.get_label_frequency(), split.split_point, left_tree.label_frequency, right_tree.label_frequency);
7677

7778
left_tree.max_label_frequency = 0;
7879
right_tree.max_label_frequency = 0;
@@ -91,38 +92,38 @@ void SpecializedSolver::get_best_left_right_scores(const Dataview& dataview, int
9192

9293
for (int current_feature_index = 0; current_feature_index < dataview.get_feature_number(); current_feature_index++) {
9394
if (left_tree.classification_score + right_tree.classification_score == dataset_size) break;
94-
if (current_feature_index == feature_index) {
95-
process_depth_one_feature<true>(dataview, feature_index, split_point, current_feature_index, split_index,
95+
if (current_feature_index == split.feature) {
96+
process_depth_one_feature<true>(dataview, split.feature, split.split_point, current_feature_index, split_index,
9697
left_tree, right_tree, split_feature_split_indices, upper_bound, complexity_cost);
9798
} else {
98-
process_depth_one_feature<false>(dataview, feature_index, split_point, current_feature_index, split_index,
99+
process_depth_one_feature<false>(dataview, split.feature, split.split_point, current_feature_index, split_index,
99100
left_tree, right_tree, split_feature_split_indices, upper_bound, complexity_cost);
100101
}
101102
}
102103

103104
if (left_tree.is_leaf()) {
104-
left_optimal_dt->make_leaf(left_tree.max_label, left_tree.size - left_tree.classification_score);
105+
split.left_optimal_dt->make_leaf(left_tree.max_label, left_tree.size - left_tree.classification_score);
105106
} else {
106-
left_optimal_dt->update_split(left_tree.best_feature_index, left_tree.best_threshold,
107+
split.left_optimal_dt->update_split(left_tree.best_feature_index, left_tree.best_threshold,
107108
std::make_shared<Tree>(left_tree.best_left_label, -1),
108109
std::make_shared<Tree>(left_tree.best_right_label, -1), complexity_cost);
109-
left_optimal_dt->objective = left_tree.get_objective(complexity_cost);
110+
split.left_optimal_dt->objective = left_tree.get_objective(complexity_cost);
110111
//RUNTIME_ASSERT(left_tree.best_left_label != -1, "Left tree left label should be initialized.");
111112
//RUNTIME_ASSERT(left_tree.best_right_label != -1, "Left tree right label should be initialized.");
112113
}
113-
RUNTIME_ASSERT(left_optimal_dt->objective >= 0, "LR - Left tree misclassification score should be non-negative.");
114+
RUNTIME_ASSERT(split.left_optimal_dt->objective >= 0, "LR - Left tree misclassification score should be non-negative.");
114115

115116
if (right_tree.is_leaf()) {
116-
right_optimal_dt->make_leaf(right_tree.max_label, right_tree.size - right_tree.classification_score);
117+
split.right_optimal_dt->make_leaf(right_tree.max_label, right_tree.size - right_tree.classification_score);
117118
} else {
118-
right_optimal_dt->update_split(right_tree.best_feature_index, right_tree.best_threshold,
119+
split.right_optimal_dt->update_split(right_tree.best_feature_index, right_tree.best_threshold,
119120
std::make_shared<Tree>(right_tree.best_left_label, -1),
120121
std::make_shared<Tree>(right_tree.best_right_label, -1), complexity_cost);
121-
right_optimal_dt->objective = right_tree.get_objective(complexity_cost);
122+
split.right_optimal_dt->objective = right_tree.get_objective(complexity_cost);
122123
//RUNTIME_ASSERT(right_tree.best_left_label != -1, "Right tree left label should be initialized.");
123124
//RUNTIME_ASSERT(right_tree.best_right_label != -1, "Right tree right label should be initialized.");
124125
}
125-
RUNTIME_ASSERT(right_optimal_dt->objective >= 0, "LR - Right tree misclassification score should be non-negative.");
126+
RUNTIME_ASSERT(split.right_optimal_dt->objective >= 0, "LR - Right tree misclassification score should be non-negative.");
126127
}
127128

128129
template <bool is_same_feature>
@@ -159,11 +160,13 @@ void SpecializedSolver::process_depth_one_feature(const Dataview& dataview,
159160

160161
tree.can_skip--;
161162

162-
if (current_feature_data.unique_value_index == tree.previous_unique_value_index || tree.can_skip > 0){
163+
if (current_feature_data.unique_value_index == tree.previous_unique_value_index || tree.can_skip > 0) {
163164
tree.current_element_count++;
164165
tree.current_label_frequency[current_feature_data.label]++;
165166
tree.previous_value = current_feature_data.value;
166167
tree.previous_unique_value_index = current_feature_data.unique_value_index;
168+
RUNTIME_ASSERT(tree.current_label_frequency[current_feature_data.label] <= tree.label_frequency[current_feature_data.label]
169+
, "Current label frequency can never exceed the maximum label frequency.");
167170
continue;
168171
}
169172

code/Utilities/include/configuration.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
11
#ifndef CONFIGURATION_H
22
#define CONFIGURATION_H
33

4-
#include <chrono>
54
#include "stopwatch.h"
5+
#include "statistics.h"
66
#include <cassert>
77
#include <stdexcept>
88
#include <iostream>
99

1010
#define EPSILON 0.0000001f
1111

1212
#define PRINT_INTERMEDIARY_TIME_SOLUTIONS 0
13-
extern std::chrono::high_resolution_clock::time_point starting_time;
14-
15-
16-
1713

1814
#ifdef NDEBUG
1915
#define RUNTIME_ASSERT(cond, msg) ((void)0)
@@ -31,6 +27,9 @@ extern std::chrono::high_resolution_clock::time_point starting_time;
3127
#endif
3228

3329
struct Configuration {
30+
31+
Configuration() : stats(std::make_shared<statistics>()) {}
32+
3433
int max_depth{ 3 };
3534
int max_gap{ 0 };
3635
float max_gap_decay{ 0.0 };
@@ -40,6 +39,7 @@ struct Configuration {
4039
bool sort_gini{ false };
4140
float complexity_cost{ 0.0 };
4241
Stopwatch stopwatch;
42+
std::shared_ptr<statistics> stats{ nullptr };
4343
Configuration GetLeftSubtreeConfig() const;
4444
Configuration GetRightSubtreeConfig(int left_gap) const;
4545
};

code/Utilities/include/statistics.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
class statistics {
55
public:
6-
static unsigned long long total_number_of_specialized_solver_calls;
7-
static unsigned long long total_number_of_general_solver_calls;
8-
static unsigned long long total_number_cache_hits;
6+
unsigned long long total_number_of_specialized_solver_calls{ 0 };
7+
unsigned long long total_number_of_general_solver_calls{ 0 };
8+
unsigned long long total_number_cache_hits{ 0 };
99

10-
static bool should_print;
10+
bool should_print{ false };
1111

12-
static void print_statistics();
12+
void print_statistics();
1313
};
1414

1515
#endif // STATISTICS_H

0 commit comments

Comments
 (0)