42 template <
class FeatureType,
49 , num_of_features_(1000)
50 , num_of_thresholds_(10)
51 , feature_handler_(nullptr)
52 , stats_estimator_(nullptr)
58 template <
class FeatureType,
66 template <
class FeatureType,
75 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
76 const std::size_t num_of_examples = examples_.size();
79 std::vector<FeatureType> features;
80 feature_handler_->createRandomFeatures(num_of_features_, features);
86 std::vector<std::vector<float>> feature_results(num_of_features_);
87 std::vector<std::vector<unsigned char>> flags(num_of_features_);
89 for (std::size_t feature_index = 0; feature_index < num_of_features_;
91 feature_results[feature_index].reserve(num_of_examples);
92 flags[feature_index].reserve(num_of_examples);
94 feature_handler_->evaluateFeature(features[feature_index],
97 feature_results[feature_index],
98 flags[feature_index]);
102 std::vector<std::vector<std::vector<float>>> branch_feature_results(
104 std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
106 std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
108 std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
112 for (std::size_t feature_index = 0; feature_index < num_of_features_;
114 branch_feature_results[feature_index].resize(1);
115 branch_flags[feature_index].resize(1);
116 branch_examples[feature_index].resize(1);
117 branch_label_data[feature_index].resize(1);
119 branch_feature_results[feature_index][0] = feature_results[feature_index];
120 branch_flags[feature_index][0] = flags[feature_index];
121 branch_examples[feature_index][0] = examples_;
122 branch_label_data[feature_index][0] = label_data_;
125 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
127 std::vector<std::vector<float>> thresholds(num_of_features_);
129 for (std::size_t feature_index = 0; feature_index < num_of_features_;
131 thresholds.reserve(num_of_thresholds_);
132 createThresholdsUniform(num_of_thresholds_,
133 feature_results[feature_index],
134 thresholds[feature_index]);
138 int best_feature_index = -1;
139 float best_feature_threshold = 0.0f;
140 float best_feature_information_gain = 0.0f;
142 for (std::size_t feature_index = 0; feature_index < num_of_features_;
144 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
146 float information_gain = 0.0f;
147 for (std::size_t branch_index = 0;
148 branch_index < branch_feature_results[feature_index].size();
150 const float branch_information_gain =
151 stats_estimator_->computeInformationGain(
153 branch_examples[feature_index][branch_index],
154 branch_label_data[feature_index][branch_index],
155 branch_feature_results[feature_index][branch_index],
156 branch_flags[feature_index][branch_index],
157 thresholds[feature_index][threshold_index]);
160 branch_information_gain *
161 branch_feature_results[feature_index][branch_index].size();
164 if (information_gain > best_feature_information_gain) {
165 best_feature_information_gain = information_gain;
166 best_feature_index =
static_cast<int>(feature_index);
167 best_feature_threshold = thresholds[feature_index][threshold_index];
173 fern.
accessFeature(depth_index) = features[best_feature_index];
177 for (std::size_t feature_index = 0; feature_index < num_of_features_;
179 std::vector<std::vector<float>>& cur_branch_feature_results =
180 branch_feature_results[feature_index];
181 std::vector<std::vector<unsigned char>>& cur_branch_flags =
182 branch_flags[feature_index];
183 std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
184 branch_examples[feature_index];
185 std::vector<std::vector<LabelType>>& cur_branch_label_data =
186 branch_label_data[feature_index];
188 const std::size_t total_num_of_new_branches =
189 num_of_branches * cur_branch_feature_results.size();
191 std::vector<std::vector<float>> new_branch_feature_results(
192 total_num_of_new_branches);
193 std::vector<std::vector<unsigned char>> new_branch_flags(
194 total_num_of_new_branches);
195 std::vector<std::vector<ExampleIndex>> new_branch_examples(
196 total_num_of_new_branches);
197 std::vector<std::vector<LabelType>> new_branch_label_data(
198 total_num_of_new_branches);
200 for (std::size_t branch_index = 0;
201 branch_index < cur_branch_feature_results.size();
203 const std::size_t num_of_examples_in_this_branch =
204 cur_branch_feature_results[branch_index].size();
206 std::vector<unsigned char> branch_indices;
207 branch_indices.reserve(num_of_examples_in_this_branch);
209 stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
210 cur_branch_flags[branch_index],
211 best_feature_threshold,
215 const std::size_t base_branch_index = branch_index * num_of_branches;
216 for (std::size_t example_index = 0;
217 example_index < num_of_examples_in_this_branch;
219 const std::size_t combined_branch_index =
220 base_branch_index + branch_indices[example_index];
222 new_branch_feature_results[combined_branch_index].push_back(
223 cur_branch_feature_results[branch_index][example_index]);
224 new_branch_flags[combined_branch_index].push_back(
225 cur_branch_flags[branch_index][example_index]);
226 new_branch_examples[combined_branch_index].push_back(
227 cur_branch_examples[branch_index][example_index]);
228 new_branch_label_data[combined_branch_index].push_back(
229 cur_branch_label_data[branch_index][example_index]);
233 branch_feature_results[feature_index] = new_branch_feature_results;
234 branch_flags[feature_index] = new_branch_flags;
235 branch_examples[feature_index] = new_branch_examples;
236 branch_label_data[feature_index] = new_branch_label_data;
242 std::vector<std::vector<float>> final_feature_results(
244 std::vector<std::vector<unsigned char>> final_flags(
246 std::vector<std::vector<unsigned char>> final_branch_indices(
248 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
249 final_feature_results[depth_index].reserve(num_of_examples);
250 final_flags[depth_index].reserve(num_of_examples);
251 final_branch_indices[depth_index].reserve(num_of_examples);
253 feature_handler_->evaluateFeature(fern.
accessFeature(depth_index),
256 final_feature_results[depth_index],
257 final_flags[depth_index]);
259 stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
260 final_flags[depth_index],
262 final_branch_indices[depth_index]);
266 std::vector<std::vector<LabelType>> node_labels(
268 std::vector<std::vector<ExampleIndex>> node_examples(
271 for (std::size_t example_index = 0; example_index < num_of_examples;
273 std::size_t node_index = 0;
274 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
275 node_index *= num_of_branches;
276 node_index += final_branch_indices[depth_index][example_index];
279 node_labels[node_index].push_back(label_data_[example_index]);
280 node_examples[node_index].push_back(examples_[example_index]);
284 const std::size_t num_of_nodes = 0x1 << fern_depth_;
285 for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
286 stats_estimator_->computeAndSetNodeStats(data_set_,
287 node_examples[node_index],
288 node_labels[node_index],
293 template <
class FeatureType,
301 std::vector<float>& values,
302 std::vector<float>& thresholds)
305 float min_value = ::std::numeric_limits<float>::max();
306 float max_value = -::std::numeric_limits<float>::max();
308 const std::size_t num_of_values = values.size();
309 for (
int value_index = 0; value_index < num_of_values; ++value_index) {
310 const float value = values[value_index];
312 if (value < min_value)
314 if (value > max_value)
318 const float range = max_value - min_value;
319 const float step = range / (num_of_thresholds + 2);
322 thresholds.resize(num_of_thresholds);
324 for (
int threshold_index = 0; threshold_index < num_of_thresholds;
326 thresholds[threshold_index] = min_value + step * (threshold_index + 1);
Class representing a Fern.
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
void initialize(const std::size_t num_of_decisions)
Initializes the fern.
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
virtual ~FernTrainer()
Destructor.
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the set training data and settings.
FernTrainer()
Constructor.