41 #include <pcl/ml/branch_estimator.h>
42 #include <pcl/ml/stats_estimator.h>
50 template <
class FeatureType,
class LabelType>
66 feature.serialize(stream);
68 stream.write(
reinterpret_cast<const char*
>(&threshold),
sizeof(threshold));
70 stream.write(
reinterpret_cast<const char*
>(&value),
sizeof(value));
71 stream.write(
reinterpret_cast<const char*
>(&variance),
sizeof(variance));
73 const int num_of_sub_nodes =
static_cast<int>(sub_nodes.size());
74 stream.write(
reinterpret_cast<const char*
>(&num_of_sub_nodes),
75 sizeof(num_of_sub_nodes));
76 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
77 sub_nodes[sub_node_index].serialize(stream);
88 feature.deserialize(stream);
90 stream.read(
reinterpret_cast<char*
>(&threshold),
sizeof(threshold));
92 stream.read(
reinterpret_cast<char*
>(&value),
sizeof(value));
93 stream.read(
reinterpret_cast<char*
>(&variance),
sizeof(variance));
96 stream.read(
reinterpret_cast<char*
>(&num_of_sub_nodes),
sizeof(num_of_sub_nodes));
97 sub_nodes.resize(num_of_sub_nodes);
99 if (num_of_sub_nodes > 0) {
100 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes;
102 sub_nodes[sub_node_index].deserialize(stream);
125 template <
class LabelDataType,
class NodeType,
class DataSet,
class ExampleIndex>
131 : branch_estimator_(branch_estimator)
142 return branch_estimator_->getNumOfBranches();
166 std::vector<ExampleIndex>& examples,
167 std::vector<LabelDataType>& label_data,
168 std::vector<float>& results,
169 std::vector<unsigned char>& flags,
170 const float threshold)
const
172 const std::size_t num_of_examples = examples.size();
173 const std::size_t num_of_branches = getNumOfBranches();
176 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
177 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
178 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
180 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
181 branch_element_count[branch_index] = 1;
182 ++branch_element_count[num_of_branches];
185 for (std::size_t example_index = 0; example_index < num_of_examples;
187 unsigned char branch_index;
189 results[example_index], flags[example_index], threshold, branch_index);
191 LabelDataType label = label_data[example_index];
193 sums[branch_index] += label;
194 sums[num_of_branches] += label;
196 sqr_sums[branch_index] += label * label;
197 sqr_sums[num_of_branches] += label * label;
199 ++branch_element_count[branch_index];
200 ++branch_element_count[num_of_branches];
203 std::vector<float> variances(num_of_branches + 1, 0);
204 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
206 const float mean_sum =
207 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
208 const float mean_sqr_sum =
static_cast<float>(sqr_sums[branch_index]) /
209 branch_element_count[branch_index];
210 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
213 float information_gain = variances[num_of_branches];
214 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
217 const float weight =
static_cast<float>(branch_element_count[branch_index]) /
218 static_cast<float>(branch_element_count[num_of_branches]);
219 information_gain -= weight * variances[branch_index];
222 return information_gain;
234 std::vector<unsigned char>& flags,
235 const float threshold,
236 std::vector<unsigned char>& branch_indices)
const
238 const std::size_t num_of_results = results.size();
239 const std::size_t num_of_branches = getNumOfBranches();
241 branch_indices.resize(num_of_results);
242 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
243 unsigned char branch_index;
245 results[result_index], flags[result_index], threshold, branch_index);
246 branch_indices[result_index] = branch_index;
259 const unsigned char flag,
260 const float threshold,
261 unsigned char& branch_index)
const
263 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
277 std::vector<ExampleIndex>& examples,
278 std::vector<LabelDataType>& label_data,
279 NodeType& node)
const
281 const std::size_t num_of_examples = examples.size();
283 LabelDataType sum = 0.0f;
284 LabelDataType sqr_sum = 0.0f;
285 for (std::size_t example_index = 0; example_index < num_of_examples;
287 const LabelDataType label = label_data[example_index];
290 sqr_sum += label * label;
293 sum /= num_of_examples;
294 sqr_sum /= num_of_examples;
296 const float variance = sqr_sum - sum * sum;
299 node.variance = variance;
310 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
311 "generateCodeForBranchIndex(...)";
322 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
323 "generateCodeForBranchIndex(...)";
Interface for branch estimators.
Node for a regression tree which optimizes variance.
RegressionVarianceNode()
Constructor.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
LabelType variance
The variance of the labels that ended up at this node during training.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
float threshold
The threshold applied on the feature response.
FeatureType feature
The feature associated with the node.
LabelType value
The label value of this node.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
virtual ~RegressionVarianceNode()
Destructor.
Statistics estimator for regression trees which optimizes variance.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
virtual ~RegressionVarianceStatsEstimator()
Destructor.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
std::size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
Class interface for gathering statistics for decision tree learning.
Define standard C methods and C++ classes that are common to all methods.