40 #include <pcl/common/common.h>
41 #include <pcl/ml/branch_estimator.h>
42 #include <pcl/ml/stats_estimator.h>
// Template header of the variance-optimizing regression tree node.
// NOTE(review): the class declaration line that follows this header is not
// present here; judging by the members used below it appears to be
// RegressionVarianceNode — confirm.
//   FeatureType - type of the node's `feature` member (serialized through
//                 feature.serialize() / feature.deserialize()).
//   LabelType   - type of the node's `value` and `variance` members.
50 template <
class FeatureType,
class LabelType>
// serialize(std::ostream& stream) const
// Writes this node and, recursively, every sub-node to `stream`, in the same
// order deserialize() reads them back:
//   1. feature            (delegated to the feature's own serialize())
//   2. threshold          (raw bytes, sizeof(threshold))
//   3. value              (raw bytes, sizeof(value))
//   4. variance           (raw bytes, sizeof(variance))
//   5. child count        (as an int)
//   6. each child node    (its full subtree, before the next sibling)
// Scalars are written as their in-memory object representation, so the format
// depends on the host's type sizes and byte order.
// NOTE(review): the stream's error state is never checked after the writes.
63 feature.serialize(stream);
65 stream.write(reinterpret_cast<const char*>(&threshold),
sizeof(threshold));
67 stream.write(reinterpret_cast<const char*>(&value),
sizeof(value));
68 stream.write(reinterpret_cast<const char*>(&variance),
sizeof(variance));
// The child count is narrowed from size_t to int before being written.
70 const int num_of_sub_nodes =
static_cast<int>(sub_nodes.size());
71 stream.write(reinterpret_cast<const char*>(&num_of_sub_nodes),
72 sizeof(num_of_sub_nodes));
// Recurse into each child so the whole subtree follows the count.
73 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
74 sub_nodes[sub_node_index].serialize(stream);
// deserialize(std::istream& stream)
// Restores this node (and, recursively, its sub-nodes) from `stream`. Fields
// are read in the order serialize() writes them: feature, threshold, value,
// variance, child count, then each child.
// NOTE(review): read failures are not checked; on a truncated or corrupt
// stream the fields may hold partial or unspecified data.
85 feature.deserialize(stream);
87 stream.read(reinterpret_cast<char*>(&threshold),
sizeof(threshold));
89 stream.read(reinterpret_cast<char*>(&value),
sizeof(value));
90 stream.read(reinterpret_cast<char*>(&variance),
sizeof(variance));
// Read the child count and size sub_nodes to match.
// NOTE(review): the declaration of num_of_sub_nodes is not shown here. If it
// is a signed int (as serialize() writes it), a negative value from a corrupt
// stream converts to a huge size in resize(). Consider validating it first.
93 stream.read(reinterpret_cast<char*>(&num_of_sub_nodes),
sizeof(num_of_sub_nodes));
94 sub_nodes.resize(num_of_sub_nodes);
// The `> 0` guard is redundant with the loop condition. Each child is
// deserialized in place.
96 if (num_of_sub_nodes > 0) {
97 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes;
99 sub_nodes[sub_node_index].deserialize(stream);
// Template header of the variance-optimizing regression statistics estimator
// (presumably RegressionVarianceStatsEstimator — the class line itself is not
// shown here).
//   LabelDataType - type of the per-example label values (label_data entries).
//   NodeType      - tree node type whose statistics are set (has `variance`).
//   DataSet       - data set type passed to the statistics methods; it appears
//                   not to be read in the visible method bodies.
//   ExampleIndex  - element type of the `examples` index vectors.
122 template <
class LabelDataType,
class NodeType,
class DataSet,
class ExampleIndex>
// Constructor initializer: stores the supplied BranchEstimator pointer.
// getNumOfBranches() and computeBranchIndex() forward to this pointer.
// NOTE(review): this looks like a non-owning raw pointer with no null check
// here. The caller presumably must keep the estimator alive — confirm.
128 : branch_estimator_(branch_estimator)
// getNumOfBranches(): returns the number of branches reported by the wrapped
// BranchEstimator, i.e. the number of branches the corresponding tree has.
136 return branch_estimator_->getNumOfBranches();
// computeInformationGain(data_set, examples, label_data, results, flags, threshold)
// Scores a candidate split by variance reduction:
//   gain = Var(all) - sum_b (n_b / n_total) * Var(branch b)
// Each example i is assigned to a branch from (results[i], flags[i],
// threshold), and Var = E[label^2] - E[label]^2 within each group.
// Each branch count starts at 1, and the total at num_of_branches. These
// pseudo-counts prevent division by zero for empty branches. They also
// slightly bias the per-group statistics, and the weights still sum to 1.
// NOTE(review): label_data, results and flags are indexed by example position
// and must hold at least examples.size() entries. Only examples.size() appears
// to be used from `examples`.
160 std::vector<ExampleIndex>& examples,
161 std::vector<LabelDataType>& label_data,
162 std::vector<float>& results,
163 std::vector<unsigned char>& flags,
164 const float threshold)
const
166 const std::size_t num_of_examples = examples.size();
167 const std::size_t num_of_branches = getNumOfBranches();
// Slot num_of_branches holds the totals over all branches (the unsplit node).
170 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
171 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
172 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
// One pseudo-count per branch, mirrored in the total count.
174 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
175 branch_element_count[branch_index] = 1;
176 ++branch_element_count[num_of_branches];
// Accumulate per-branch and total label sums, squared sums and counts.
179 for (std::size_t example_index = 0; example_index < num_of_examples;
181 unsigned char branch_index;
183 results[example_index], flags[example_index], threshold, branch_index);
185 LabelDataType label = label_data[example_index];
187 sums[branch_index] += label;
188 sums[num_of_branches] += label;
190 sqr_sums[branch_index] += label * label;
191 sqr_sums[num_of_branches] += label * label;
193 ++branch_element_count[branch_index];
194 ++branch_element_count[num_of_branches];
// Per-group variance as E[x^2] - E[x]^2. This form can come out slightly
// negative through floating-point cancellation when the labels are large
// relative to their spread.
197 std::vector<float> variances(num_of_branches + 1, 0);
198 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
200 const float mean_sum =
201 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
202 const float mean_sqr_sum =
static_cast<float>(sqr_sums[branch_index]) /
203 branch_element_count[branch_index];
204 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
// Start from the unsplit variance and subtract each branch's variance,
// weighted by its (smoothed) share of the examples.
207 float information_gain = variances[num_of_branches];
208 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
211 const float weight =
static_cast<float>(branch_element_count[branch_index]) /
212 static_cast<float>(branch_element_count[num_of_branches]);
213 information_gain -= weight * variances[branch_index];
216 return information_gain;
// computeBranchIndices(results, flags, threshold, branch_indices)
// Resizes branch_indices to results.size(). Entry i is then filled with the
// branch index computed for (results[i], flags[i], threshold), one element at
// a time (presumably via computeBranchIndex).
// NOTE(review): flags must hold at least results.size() entries.
// NOTE(review): num_of_branches appears unused in this body.
228 std::vector<unsigned char>& flags,
229 const float threshold,
230 std::vector<unsigned char>& branch_indices)
const
232 const std::size_t num_of_results = results.size();
233 const std::size_t num_of_branches = getNumOfBranches();
235 branch_indices.resize(num_of_results);
236 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
237 unsigned char branch_index;
239 results[result_index], flags[result_index], threshold, branch_index);
240 branch_indices[result_index] = branch_index;
// computeBranchIndex(result, flag, threshold, branch_index)
// Thin forwarder: delegates to the wrapped BranchEstimator. The estimator
// receives `branch_index` by reference and is expected to write the selected
// branch into it.
253 const unsigned char flag,
254 const float threshold,
255 unsigned char& branch_index)
const
257 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
// computeAndSetNodeStats(data_set, examples, label_data, node)
// Averages the first examples.size() entries of label_data. The variance
// E[x^2] - E[x]^2 is then stored in node.variance.
// NOTE(review): if `examples` is empty, num_of_examples is 0 and both
// divisions divide by zero. That gives NaN for floating-point LabelDataType
// and undefined behavior for integral types. Consider guarding this case.
// NOTE(review): for an integral LabelDataType the divisions truncate.
271 std::vector<ExampleIndex>& examples,
272 std::vector<LabelDataType>& label_data,
273 NodeType& node)
const
275 const std::size_t num_of_examples = examples.size();
277 LabelDataType sum = 0.0f;
278 LabelDataType sqr_sum = 0.0f;
279 for (std::size_t example_index = 0; example_index < num_of_examples;
281 const LabelDataType label = label_data[example_index];
284 sqr_sum += label * label;
// Convert the running sums into means: E[x] and E[x^2].
287 sum /= num_of_examples;
288 sqr_sum /= num_of_examples;
290 const float variance = sqr_sum - sum * sum;
293 node.variance = variance;
// Presumably the body of generateCodeForBranchIndexComputation(node, stream).
// This estimator does not support branch-index code generation, so an error
// message is written to `stream` instead of code.
304 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
305 "generateCodeForBranchIndex(...)";
// Presumably the body of generateCodeForOutput(node, stream). This estimator
// does not support output code generation, so an error message is written to
// `stream` instead of code.
// NOTE(review): the message names generateCodeForBranchIndex(...), which looks
// copy-pasted from the branch-index generator. If this is the output
// generator, it should probably name generateCodeForOutput(...). Confirm
// before changing: it is runtime text.
316 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
317 "generateCodeForBranchIndex(...)";
Statistics estimator for regression trees which optimizes variance.
LabelType value
The label value of this node.
Node for a regression tree that optimizes variance.
void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const
Computes and sets the statistics for a node.
LabelType variance
The variance of the labels that ended up at this node during training.
FeatureType feature
The feature associated with the node.
void serialize(std::ostream &stream) const
Serializes the node to the specified stream.
float threshold
The threshold applied to the feature response.
void computeBranchIndex(const float result, const unsigned char flag, const float threshold, unsigned char &branch_index) const
Computes the branch index for the specified result.
RegressionVarianceNode()
Constructor.
Class interface for gathering statistics for decision tree learning.
void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const
Computes the branch indices for all supplied results.
RegressionVarianceStatsEstimator(BranchEstimator *branch_estimator)
Constructor.
std::vector< RegressionVarianceNode > sub_nodes
The child nodes.
void deserialize(std::istream &stream)
Deserializes a node from the specified stream.
void generateCodeForBranchIndexComputation(NodeType &node, std::ostream &stream) const
Generates code for branch index computation.
LabelDataType getLabelOfNode(NodeType &node) const
Returns the label of the specified node.
void generateCodeForOutput(NodeType &node, std::ostream &stream) const
Generates code for label output.
float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const
Computes the information gain obtained by the specified threshold.
Interface for branch estimators.
std::size_t getNumOfBranches() const
Returns the number of branches the corresponding tree has.