-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
optimize CPU inference with Array-Based Tree Traversal #11519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 28 commits
e64e20c
60c2ffe
8f6dfe3
bd13491
3827a49
7334bd2
8356855
165b34a
52eee0c
cb28530
7ae3a42
2799644
a8bb91e
6d94176
e34becc
a2f2c75
2afad25
92ac69e
e2b0f05
87bee15
c3c1c85
d270ee7
2a7e575
3539ec0
709d233
40be7e2
c9160c6
de552e8
9c1007f
92b5069
b0eaa85
790a98e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,281 @@ | ||||||
/**
 * Copyright 2021-2025, XGBoost Contributors
 * \file array_tree_layout.h
 * \brief Implementation of array tree layout -- a powerful inference optimization method.
 */
#ifndef XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ | ||||||
#define XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ | ||||||
|
||||||
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <vector>
|
||||||
namespace xgboost::predictor { | ||||||
|
||||||
/**
 * @brief The class holds the array-based representation of the top levels of a single tree.
 *
 * @tparam TreeType The type of the original tree (RegTree or MultiTargetTree)
 *
 * @tparam has_categorical if the tree has categorical features
 *
 * @tparam any_missing if the class is able to process missing values
 *
 * @tparam kNumDeepLevels number of tree levels being unrolled into the array-based structure
 */
template <class TreeType, bool has_categorical, bool any_missing, int kNumDeepLevels> | ||||||
class ArrayTreeLayout { | ||||||
private: | ||||||
/* Number of nodes in the array based representation of the top levels of the tree | ||||||
*/ | ||||||
constexpr static size_t kNodesCount = (1u << kNumDeepLevels) - 1; | ||||||
|
||||||
struct Empty {}; | ||||||
using DefaultLeftType = | ||||||
typename std::conditional_t<any_missing, | ||||||
std::array<uint8_t, kNodesCount>, | ||||||
struct Empty>; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
using IsCatType = | ||||||
typename std::conditional_t<has_categorical, | ||||||
std::array<uint8_t, kNodesCount>, | ||||||
struct Empty>; | ||||||
using CatSegmentType = | ||||||
typename std::conditional_t<has_categorical, | ||||||
std::array<common::Span<uint32_t const>, kNodesCount>, | ||||||
struct Empty>; | ||||||
|
||||||
DefaultLeftType default_left_; | ||||||
IsCatType is_cat_; | ||||||
CatSegmentType cat_segment_; | ||||||
|
||||||
std::array<bst_feature_t, kNodesCount> split_index_; | ||||||
std::array<float, kNodesCount> split_cond_; | ||||||
/* The nodes at tree levels 0, 1, ..., kNumDeepLevels - 1 are unrolled into an array-based structure. | ||||||
* If the tree has additional levels, this array stores the node indices of the sub-trees at level kNumDeepLevels. | ||||||
* This is necessary to continue processing nodes that are not eligible for array-based unrolling. | ||||||
* The number of sub-trees packed into this array is equal to the number of nodes at tree level kNumDeepLevels, | ||||||
* which is calculated as (1u << kNumDeepLevels) == kNodesCount + 1. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if the tree is not well balanced and is more like a linked list than a tree? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case we add dummy nodes with So with array layout we have to allocated all nodes, but keep some of the unpopulated in case the tree is pure balanced. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for sharing, that makes sense. |
||||||
*/ | ||||||
std::array<bst_node_t, kNodesCount + 1> nidx_in_tree_; | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
inline static bool IsLeaf(const RegTree& tree, bst_node_t nidx) { | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].IsLeaf(); | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
} | ||||||
|
||||||
inline static bool IsLeaf(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.IsLeaf(nidx); | ||||||
} | ||||||
|
||||||
inline static uint8_t DefaultLeft(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].DefaultLeft(); | ||||||
} | ||||||
|
||||||
inline static uint8_t DefaultLeft(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.DefaultLeft(nidx); | ||||||
} | ||||||
|
||||||
/* Feature index used by the split at node nidx of a RegTree.
 * Only usable when TreeType is RegTree.
 */
static inline bst_feature_t SplitIndex(RegTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, RegTree>);
  auto const& node = tree[nidx];
  return node.SplitIndex();
}

/* Feature index used by the split at node nidx of a MultiTargetTree.
 * Only usable when TreeType is MultiTargetTree.
 */
static inline bst_feature_t SplitIndex(MultiTargetTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, MultiTargetTree>);
  return tree.SplitIndex(nidx);
}
|
||||||
inline static float SplitCond(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].SplitCond(); | ||||||
} | ||||||
|
||||||
inline static float SplitCond(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.SplitCond(nidx); | ||||||
} | ||||||
|
||||||
/* Index of the left child of node nidx of a RegTree.
 * Only usable when TreeType is RegTree.
 */
static inline bst_node_t LeftChild(RegTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, RegTree>);
  auto const& node = tree[nidx];
  return node.LeftChild();
}

/* Index of the left child of node nidx of a MultiTargetTree.
 * Only usable when TreeType is MultiTargetTree.
 */
static inline bst_node_t LeftChild(MultiTargetTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, MultiTargetTree>);
  return tree.LeftChild(nidx);
}
|
||||||
/* Index of the right child of node nidx of a RegTree.
 * Only usable when TreeType is RegTree.
 * The value is derived as "left child + 1" instead of being read from the node.
 * NOTE(review): presumably this relies on sibling nodes being stored adjacently
 * in RegTree -- confirm against the tree implementation.
 */
static inline bst_node_t RightChild(RegTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, RegTree>);
  auto const& node = tree[nidx];
  return node.LeftChild() + 1;
}

/* Index of the right child of node nidx of a MultiTargetTree, as reported by the tree.
 * Only usable when TreeType is MultiTargetTree.
 */
static inline bst_node_t RightChild(MultiTargetTree const& tree, bst_node_t nidx) {
  static_assert(std::is_same_v<TreeType, MultiTargetTree>);
  return tree.RightChild(nidx);
}
|
||||||
/** | ||||||
* @brief Traverse the top levels of original tree and fill internal arrays | ||||||
* | ||||||
* @tparam TreeType The type of the origianl tree (RegTree or MultiTargetTree) | ||||||
* | ||||||
* @tparam depth the tree level being processing | ||||||
* | ||||||
* @param tree the original tree | ||||||
* | ||||||
* @param cats matrix of categorical splits | ||||||
* | ||||||
* @param nidx_array node idx in the array layout | ||||||
* | ||||||
* @param nidx node idx in the original tree | ||||||
* | ||||||
*/ | ||||||
template <int depth = 0> | ||||||
void inline Populate(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats, | ||||||
bst_node_t nidx_array = 0, bst_node_t nidx = 0) { | ||||||
if constexpr (depth == kNumDeepLevels + 1) { | ||||||
return; | ||||||
} else if constexpr (depth == kNumDeepLevels) { | ||||||
/* We store the node index in the original tree to ensure continued processing | ||||||
* for nodes that are not eligible for array layout optimization. | ||||||
*/ | ||||||
nidx_in_tree_[nidx_array - kNodesCount] = nidx; | ||||||
} else { | ||||||
if (IsLeaf(tree, nidx)) { | ||||||
split_index_[nidx_array] = 0; | ||||||
|
||||||
/* | ||||||
* If the tree is not fully populated, we can reduce transfer costs. | ||||||
* The values for the unpopulated parts of the tree are set to ensure | ||||||
* that any move will always proceed in the "right" direction. | ||||||
* This is achieved by exploiting the fact that comparisons with NaN always result in false. | ||||||
*/ | ||||||
if constexpr (any_missing) default_left_[nidx_array] = 0; | ||||||
if constexpr (has_categorical) is_cat_[nidx_array] = 0; | ||||||
split_cond_[nidx_array] = std::numeric_limits<float>::quiet_NaN(); | ||||||
|
||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, nidx); | ||||||
} else { | ||||||
if constexpr (any_missing) default_left_[nidx_array] = DefaultLeft(tree, nidx); | ||||||
if constexpr (has_categorical) { | ||||||
is_cat_[nidx_array] = common::IsCat(cats.split_type, nidx); | ||||||
if (is_cat_[nidx_array]) { | ||||||
cat_segment_[nidx_array] = cats.categories.subspan(cats.node_ptr[nidx].beg, | ||||||
cats.node_ptr[nidx].size); | ||||||
} | ||||||
} | ||||||
|
||||||
split_index_[nidx_array] = SplitIndex(tree, nidx); | ||||||
split_cond_[nidx_array] = SplitCond(tree, nidx); | ||||||
|
||||||
/* | ||||||
* LeftChild is used to determine if a node is a leaf, so it is always a valid value. | ||||||
* However, RightChild can be invalid in some exotic cases. | ||||||
* A tree with an invalid RightChild can still be correctly processed using classical methods | ||||||
* if the split conditions are correct. | ||||||
* However, in an array layout, an invalid RightChild, even if unreachable, can lead to memory corruption. | ||||||
* A check should be added to prevent this. | ||||||
*/ | ||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 1, LeftChild(tree, nidx)); | ||||||
bst_node_t right_child = RightChild(tree, nidx); | ||||||
if (right_child != RegTree::kInvalidNodeId) { | ||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, right_child); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
bool inline GetDecision(float fvalue, bst_node_t nidx) const { | ||||||
if constexpr (has_categorical) { | ||||||
if (is_cat_[nidx]) { | ||||||
return common::Decision(cat_segment_[nidx], fvalue); | ||||||
} else { | ||||||
return fvalue < split_cond_[nidx]; | ||||||
} | ||||||
} else { | ||||||
return fvalue < split_cond_[nidx]; | ||||||
} | ||||||
} | ||||||
|
||||||
public: | ||||||
/* Ad-hoc value. | ||||||
* Increasing doesn't lead to perf gain, since bottleneck is now at gather instructions. | ||||||
*/ | ||||||
constexpr static int kMaxNumDeepLevels = 6; | ||||||
static_assert(kNumDeepLevels <= kMaxNumDeepLevels); | ||||||
|
||||||
/* Build the array layout by populating it from the root (level 0, node 0) of tree. */
ArrayTreeLayout(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats) {
  Populate<0>(tree, cats, /*nidx_array=*/0, /*nidx=*/0);
}
|
||||||
/** | ||||||
* @brief | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
* Traverse the top levels of the tree for the entire block_size. | ||||||
* In the array layout, it is organized to guarantee that | ||||||
* if a node at the current level has index nidx, then | ||||||
* the node index for the left child at the next level is always 2*nidx, and | ||||||
* the node index for the right child at the next level is always 2*nidx+1. | ||||||
* This greatly improves data locality. | ||||||
* | ||||||
* @param thread_temp buffer holding the feature values | ||||||
* | ||||||
* @param offset offset of the current data block | ||||||
* | ||||||
* @param block_size size of the current block (1 < block_size <= 64) | ||||||
* | ||||||
* @param p_nidx pointer to the vector of node indexes in the original tree, | ||||||
* corresponding to the level next after kNumDeepLevels | ||||||
*/ | ||||||
void inline Process(std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset, | ||||||
std::size_t const block_size, bst_node_t* p_nidx) { | ||||||
for (int depth = 0; depth < kNumDeepLevels; ++depth) { | ||||||
std::size_t first_node = (1u << depth) - 1; | ||||||
|
||||||
for (std::size_t i = 0; i < block_size; ++i) { | ||||||
bst_node_t idx = p_nidx[i]; | ||||||
|
||||||
const auto& feat = thread_temp[offset + i]; | ||||||
bst_feature_t split = split_index_[first_node + idx]; | ||||||
auto fvalue = feat.GetFvalue(split); | ||||||
if constexpr (any_missing) { | ||||||
bool go_left = feat.IsMissing(split) ? default_left_[first_node + idx] | ||||||
: GetDecision(fvalue, first_node + idx); | ||||||
p_nidx[i] = 2 * idx + !go_left; | ||||||
} else { | ||||||
p_nidx[i] = 2 * idx + !GetDecision(fvalue, first_node + idx); | ||||||
} | ||||||
} | ||||||
} | ||||||
for (std::size_t i = 0; i < block_size; ++i) { | ||||||
p_nidx[i] = nidx_in_tree_[p_nidx[i]]; | ||||||
} | ||||||
} | ||||||
}; | ||||||
|
||||||
template <class TreeType, bool has_categorical, bool any_missing, int num_deep_levels = 1> | ||||||
void inline ProcessArrayTree(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats, | ||||||
std::vector<RegTree::FVec> const &thread_temp, | ||||||
std::size_t const offset, std::size_t const block_size, | ||||||
bst_node_t* p_nidx, int tree_depth) { | ||||||
constexpr int kMaxNumDeepLevels = | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, 0>::kMaxNumDeepLevels; | ||||||
|
||||||
if constexpr (num_deep_levels == kMaxNumDeepLevels) { | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, num_deep_levels> buffer(tree, cats); | ||||||
buffer.Process(thread_temp, offset, block_size, p_nidx); | ||||||
} else { | ||||||
if (tree_depth <= num_deep_levels) { | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, num_deep_levels> buffer(tree, cats); | ||||||
buffer.Process(thread_temp, offset, block_size, p_nidx); | ||||||
} else { | ||||||
ProcessArrayTree<TreeType, has_categorical, any_missing, num_deep_levels + 1> | ||||||
(tree, cats, thread_temp, offset, block_size, p_nidx, tree_depth); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
} // namespace xgboost::predictor | ||||||
#endif // XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ |
Uh oh!
There was an error while loading. Please reload this page.