-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
optimize CPU inference with Array-Based Tree Traversal #11519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
razdoburdin
wants to merge
32
commits into
dmlc:master
Choose a base branch
from
razdoburdin:dev/cpu/eytzinger_layout
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 20 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
e64e20c
basic implementation
60c2ffe
optimisations
8f6dfe3
fix compilation error
bd13491
perf optimzation
3827a49
add categorial
7334bd2
add multitarget
8356855
linting
165b34a
perf
52eee0c
fix perf
cb28530
refactoring
7ae3a42
add comments
2799644
more comments
a8bb91e
fix and tildy
6d94176
Update src/predictor/array_tree_layout.h
razdoburdin e34becc
add static assertions
a2f2c75
fix randome state usage in sycl training_continuation test
2afad25
Merge branch 'master' into dev/cpu/eytzinger_layout
razdoburdin 92ac69e
check if right child is valid
e2b0f05
Merge branch 'dev/cpu/eytzinger_layout' of https://github.com/razdobu…
87bee15
use signed ints for node indxes
c3c1c85
Update src/predictor/array_tree_layout.h
razdoburdin d270ee7
Update src/predictor/array_tree_layout.h
razdoburdin 2a7e575
Update src/predictor/array_tree_layout.h
razdoburdin 3539ec0
Update src/predictor/array_tree_layout.h
razdoburdin 709d233
Update src/predictor/array_tree_layout.h
razdoburdin 40be7e2
Update src/predictor/array_tree_layout.h
razdoburdin c9160c6
Update src/predictor/cpu_predictor.cc
razdoburdin de552e8
linting
9c1007f
add tests
92b5069
lint
b0eaa85
Update src/predictor/cpu_predictor.cc
razdoburdin 790a98e
Merge branch 'master' into dev/cpu/eytzinger_layout
razdoburdin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,273 @@ | ||||||
/** | ||||||
* Copyright 2021-2025, XGBoost Contributors | ||||||
* \file array_tree_layout.cc | ||||||
* \brief Implementation of array tree layout -- a powerfull inference optimization method. | ||||||
*/ | ||||||
#ifndef XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ | ||||||
#define XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ | ||||||
|
||||||
#include <limits> | ||||||
#include <vector> | ||||||
|
||||||
namespace xgboost::predictor { | ||||||
|
||||||
/** | ||||||
* @brief The class holds the array-based representation of the top levels of a single tree. | ||||||
* | ||||||
* \tparam TreeType The type of the origianl tree (RegTree or MultiTargetTree) | ||||||
* | ||||||
* \tparam has_categorical if the tree has categorical features | ||||||
* | ||||||
* \tparam any_missing if the class is able to process missing values | ||||||
* | ||||||
* \tparam kNumDeepLevels number of tree leveles being unrolled into array-based structure | ||||||
*/ | ||||||
template <class TreeType, bool has_categorical, bool any_missing, int kNumDeepLevels> | ||||||
class ArrayTreeLayout { | ||||||
private: | ||||||
constexpr static size_t kNodesCount = (1u << kNumDeepLevels) - 1; | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
struct Empty {}; | ||||||
using DefaultLeftType = | ||||||
typename std::conditional_t<any_missing, | ||||||
std::array<uint8_t, kNodesCount>, | ||||||
struct Empty>; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
using IsCatType = | ||||||
typename std::conditional_t<has_categorical, | ||||||
std::array<uint8_t, kNodesCount>, | ||||||
struct Empty>; | ||||||
using CatSegmentType = | ||||||
typename std::conditional_t<has_categorical, | ||||||
std::array<common::Span<uint32_t const>, kNodesCount>, | ||||||
struct Empty>; | ||||||
|
||||||
DefaultLeftType default_left_; | ||||||
IsCatType is_cat_; | ||||||
CatSegmentType cat_segment_; | ||||||
|
||||||
std::array<bst_feature_t, kNodesCount> split_index_; | ||||||
std::array<float, kNodesCount> split_cond_; | ||||||
std::array<bst_node_t, kNodesCount + 1> nidx_in_tree_; | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
inline static bool IsLeaf(const RegTree& tree, bst_node_t nidx) { | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].IsLeaf(); | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
} | ||||||
|
||||||
inline static bool IsLeaf(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.IsLeaf(nidx); | ||||||
} | ||||||
|
||||||
inline static uint8_t DefaultLeft(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].DefaultLeft(); | ||||||
} | ||||||
|
||||||
inline static uint8_t DefaultLeft(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.DefaultLeft(nidx); | ||||||
} | ||||||
|
||||||
inline static bst_feature_t SplitIndex(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].SplitIndex(); | ||||||
} | ||||||
|
||||||
inline static bst_feature_t SplitIndex(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.SplitIndex(nidx); | ||||||
} | ||||||
|
||||||
inline static float SplitCond(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].SplitCond(); | ||||||
} | ||||||
|
||||||
inline static float SplitCond(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.SplitCond(nidx); | ||||||
} | ||||||
|
||||||
inline static bst_node_t LeftChild(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].LeftChild(); | ||||||
} | ||||||
|
||||||
inline static bst_node_t LeftChild(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.LeftChild(nidx); | ||||||
} | ||||||
|
||||||
inline static bst_node_t RightChild(const RegTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<RegTree, TreeType>); | ||||||
return tree[nidx].LeftChild() + 1; | ||||||
} | ||||||
|
||||||
inline static bst_node_t RightChild(const MultiTargetTree& tree, bst_node_t nidx) { | ||||||
static_assert(std::is_same_v<MultiTargetTree, TreeType>); | ||||||
return tree.RightChild(nidx); | ||||||
} | ||||||
|
||||||
/** | ||||||
* @brief Traverse the top levels of original tree and fill internal arrays | ||||||
* | ||||||
* \tparam TreeType The type of the origianl tree (RegTree or MultiTargetTree) | ||||||
* | ||||||
* \tparam depth the tree level being processing | ||||||
* | ||||||
* \param tree the original tree | ||||||
* | ||||||
* \param cats matrix of categorical splits | ||||||
* | ||||||
* \param nidx_array node idx in the array layout | ||||||
* | ||||||
* \param nidx node idx in the original tree | ||||||
* | ||||||
*/ | ||||||
template <int depth = 0> | ||||||
void inline Populate(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats, | ||||||
bst_node_t nidx_array = 0, bst_node_t nidx = 0) { | ||||||
if constexpr (depth == kNumDeepLevels + 1) { | ||||||
return; | ||||||
} else if constexpr (depth == kNumDeepLevels) { | ||||||
/* We save the node index in the origianl tree to able to continue processing | ||||||
* for nodes not egligable for array layout optimisation. | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
*/ | ||||||
nidx_in_tree_[nidx_array - kNodesCount] = nidx; | ||||||
} else { | ||||||
if (IsLeaf(tree, nidx)) { | ||||||
split_index_[nidx_array] = 0; | ||||||
|
||||||
/* | ||||||
* if the tree is not fully populated, we can reduce transfering costs. | ||||||
* the values for unpopulated part of the tree are set in a way to guarantie | ||||||
* that a moove will always done in "right" direction. | ||||||
* here we exploiting that comparison with nan always results to false. | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
*/ | ||||||
if constexpr (any_missing) default_left_[nidx_array] = 0; | ||||||
if constexpr (has_categorical) is_cat_[nidx_array] = 0; | ||||||
split_cond_[nidx_array] = std::numeric_limits<float>::quiet_NaN(); | ||||||
|
||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, nidx); | ||||||
} else { | ||||||
if constexpr (any_missing) default_left_[nidx_array] = DefaultLeft(tree, nidx); | ||||||
if constexpr (has_categorical) { | ||||||
is_cat_[nidx_array] = common::IsCat(cats.split_type, nidx); | ||||||
if (is_cat_[nidx_array]) { | ||||||
cat_segment_[nidx_array] = cats.categories.subspan(cats.node_ptr[nidx].beg, | ||||||
cats.node_ptr[nidx].size); | ||||||
} | ||||||
} | ||||||
|
||||||
split_index_[nidx_array] = SplitIndex(tree, nidx); | ||||||
split_cond_[nidx_array] = SplitCond(tree, nidx); | ||||||
|
||||||
/* | ||||||
* LeftChild is used to find if the node is leaf, so it is a valid value, | ||||||
* howerwer RightChild can be invalid in some exotic case. | ||||||
* The tree with invalid right-child can be correctly processed by a classical method, | ||||||
* if the split conditions are propper. | ||||||
* But for array layout invalid RightChild, even unreachable, will lead to memory corruption. | ||||||
* Add check to prevent it. | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
*/ | ||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 1, LeftChild(tree, nidx)); | ||||||
bst_node_t right_child = RightChild(tree, nidx); | ||||||
if (right_child != RegTree::kInvalidNodeId) { | ||||||
Populate<depth + 1>(tree, cats, 2 * nidx_array + 2, right_child); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
bool inline GetDecision(float fvalue, bst_node_t nidx) const { | ||||||
if constexpr (has_categorical) { | ||||||
if (is_cat_[nidx]) { | ||||||
return common::Decision(cat_segment_[nidx], fvalue); | ||||||
} else { | ||||||
return fvalue < split_cond_[nidx]; | ||||||
} | ||||||
} else { | ||||||
return fvalue < split_cond_[nidx]; | ||||||
} | ||||||
} | ||||||
|
||||||
public: | ||||||
/* Ad-hoc value. | ||||||
* Increasing doesn't lead to perf gain, since bottleneck is now at gather instructions. | ||||||
*/ | ||||||
constexpr static int kMaxNumDeepLevels = 6; | ||||||
static_assert(kNumDeepLevels <= kMaxNumDeepLevels); | ||||||
|
||||||
ArrayTreeLayout(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats) { | ||||||
Populate(tree, cats); | ||||||
} | ||||||
|
||||||
/** | ||||||
* @brief | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
* Traverse top levels of the tree for an entire block_size. | ||||||
* In array layout is orginised to garantie that | ||||||
* if the node at the current level has index nidx, than | ||||||
* the node index for left child at the next level is always 2*nidx | ||||||
* the node index for right child at the next level is always 2*nidx+1 | ||||||
* This greatly improve data locality | ||||||
razdoburdin marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
* | ||||||
* \param thread_temp buffer holding the feature values | ||||||
* | ||||||
* \param offset offset of the current data block | ||||||
* | ||||||
* \param block_size size of the current block (1 < block_size <= 64) | ||||||
* | ||||||
* \param p_nidx pointer to the vector of node indexes in the original tree, | ||||||
* corresponding to the level next after kNumDeepLevels | ||||||
*/ | ||||||
void inline Process(std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset, | ||||||
std::size_t const block_size, bst_node_t* p_nidx) { | ||||||
for (int depth = 0; depth < kNumDeepLevels; ++depth) { | ||||||
std::size_t first_node = (1u << depth) - 1; | ||||||
|
||||||
for (std::size_t i = 0; i < block_size; ++i) { | ||||||
bst_node_t idx = p_nidx[i]; | ||||||
|
||||||
const auto& feat = thread_temp[offset + i]; | ||||||
bst_feature_t split = split_index_[first_node + idx]; | ||||||
auto fvalue = feat.GetFvalue(split); | ||||||
if constexpr (any_missing) { | ||||||
bool go_left = feat.IsMissing(split) ? default_left_[first_node + idx] | ||||||
: GetDecision(fvalue, first_node + idx); | ||||||
p_nidx[i] = 2 * idx + !go_left; | ||||||
} else { | ||||||
p_nidx[i] = 2 * idx + !GetDecision(fvalue, first_node + idx); | ||||||
} | ||||||
} | ||||||
} | ||||||
for (std::size_t i = 0; i < block_size; ++i) { | ||||||
p_nidx[i] = nidx_in_tree_[p_nidx[i]]; | ||||||
} | ||||||
} | ||||||
}; | ||||||
|
||||||
template <class TreeType, bool has_categorical, bool any_missing, int num_deep_levels = 1> | ||||||
void inline ProcessArrayTree(const TreeType& tree, RegTree::CategoricalSplitMatrix const &cats, | ||||||
std::vector<RegTree::FVec> const &thread_temp, | ||||||
std::size_t const offset, std::size_t const block_size, | ||||||
bst_node_t* p_nidx, int tree_depth) { | ||||||
constexpr int kMaxNumDeepLevels = | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, 0>::kMaxNumDeepLevels; | ||||||
|
||||||
if constexpr (num_deep_levels == kMaxNumDeepLevels) { | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, num_deep_levels> buffer(tree, cats); | ||||||
buffer.Process(thread_temp, offset, block_size, p_nidx); | ||||||
} else { | ||||||
if (tree_depth <= num_deep_levels) { | ||||||
ArrayTreeLayout<TreeType, has_categorical, any_missing, num_deep_levels> buffer(tree, cats); | ||||||
buffer.Process(thread_temp, offset, block_size, p_nidx); | ||||||
} else { | ||||||
ProcessArrayTree<TreeType, has_categorical, any_missing, num_deep_levels + 1> | ||||||
(tree, cats, thread_temp, offset, block_size, p_nidx, tree_depth); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
} // namespace xgboost::predictor | ||||||
#endif // XGBOOST_PREDICTOR_ARRAY_TREE_LAYOUT_H_ |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.