
Conversation

razdoburdin
Contributor

This PR introduces an optimization for CPU inference. For each tree, the top N levels are transformed into a compact array-based layout, which allows a branchless node-indexing rule: idx = 2 * idx + int(val < split_cond). To minimize memory overhead, the transformation from the standard tree structure to the array layout is performed on the fly for each block of data being processed. Even with this additional computation, the improved data locality of the cache-friendly array layout speeds up inference by up to ~2x (1.4x on average).
[image: benchmark results]
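As a minimal sketch of the branchless indexing rule described above (the function name, the 1-based layout, and the level-by-level storage are assumptions for illustration, not the PR's actual code):

```cpp
#include <array>
#include <cstddef>

// Hypothetical compact layout: split conditions of a complete binary tree,
// stored level by level with 1-based indexing (slot 0 is unused).
// The children of node idx live at 2*idx and 2*idx + 1, so the comparison
// result can be folded directly into the index without a branch.
template <std::size_t NumLevels>
std::size_t TraverseTopLevels(
    const std::array<float, (std::size_t{1} << NumLevels)>& split_cond,
    float val) {
  std::size_t idx = 1;  // root
  for (std::size_t level = 0; level + 1 < NumLevels; ++level) {
    // Branchless descent: idx = 2 * idx + int(val < split_cond).
    idx = 2 * idx + static_cast<std::size_t>(val < split_cond[idx]);
  }
  return idx;  // index within the compact array after the top levels
}
```

Because the loop body contains no data-dependent branch, the hot path avoids branch mispredictions, and the level-by-level array keeps the visited nodes close together in cache.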

@razdoburdin razdoburdin marked this pull request as draft June 20, 2025 13:50
@trivialfis
Member

Thank you for the optimization on the inference. Please unmark the "draft" status and ping me when the PR is ready for testing.


@Vika-F Vika-F left a comment


Cosmetic changes.

The next possible step would be to convert the trees into array-based representation only once, and not to do it for each block of data.

razdoburdin and others added 6 commits June 24, 2025 12:53
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
razdoburdin and others added 2 commits June 24, 2025 12:57
Co-authored-by: Victoriya Fedotova <viktoria.nn@gmail.com>
@razdoburdin
Contributor Author

The next possible step would be to convert the trees into array-based representation only once, and not to do it for each block of data.

That sounds reasonable and would further improve performance (at the cost of increased memory consumption).

@razdoburdin razdoburdin marked this pull request as ready for review June 24, 2025 12:24
@razdoburdin
Contributor Author

Thank you for the optimization on the inference. Please unmark the "draft" status and ping me when the PR is ready for testing.

Hi @trivialfis, the PR is ready for review.

@trivialfis
Member

cc @hcho3

Member

@trivialfis trivialfis left a comment


Still trying to understand the code, will give it a try later. In the meantime, could you please craft some specific unit tests for the new inference algorithm?

* We transform trees to the array layout for each block of data to avoid memory overhead.
* This makes the array layout inefficient for block_size == 1.
*/
const bool use_array_tree_layout = block_size > 1;
Member


What happens if this is a small online inference call? The input size could be a few samples per call.

Contributor Author


The default (original) implementation will be used.

for (std::size_t i = 0; i < block_size; ++i) {
bst_node_t nidx = 0;
if constexpr (use_array_tree_layout) {
nidx = p_nidx[i];
Member


unused?

Contributor Author


The optimized array-layout processing is effective only for nodes that are close to the root. For the other nodes, we still use the original method.

@razdoburdin
Contributor Author

Still trying to understand the code, will give it a try later. In the meantime, could you please craft some specific unit tests for the new inference algorithm?

I added some unit tests.

Member

@trivialfis trivialfis left a comment


I'm still trying to understand the code. In the meantime, let me do some refactoring this week and next to accommodate the new optimization. We need a better structure to handle all these:

  • Predict with scalar leaf.
  • Predict with vector leaf.
  • Array predict with scalar leaf.
  • Array predict with vector leaf.
  • Column split with scalar leaf.

I think I will split up the CPU predictor into multiple pieces.

* If the tree has additional levels, this array stores the node indices of the sub-trees at level kNumDeepLevels.
* This is necessary to continue processing nodes that are not eligible for array-based unrolling.
* The number of sub-trees packed into this array is equal to the number of nodes at tree level kNumDeepLevels,
* which is calculated as (1u << kNumDeepLevels) == kNodesCount + 1.
Member


What happens if the tree is not well balanced and is more like a linked list than a tree?

Contributor Author


In this case we add dummy nodes with NaN as the split condition. In these dummy nodes the decision is always "go right" (see also the comments at https://github.com/razdoburdin/xgboost/blob/b0eaa856e1246416f7f9538bcc004e7723d9b997/src/predictor/array_tree_layout.h#L154); the left children are not initialized.

So with the array layout we have to allocate all nodes, but some of them stay unpopulated unless the tree is perfectly balanced.

Initial tree:
[image: initial tree]

Tree with dummy (NaN-valued) nodes:
[image: tree with dummy nodes]

Member


Thank you for sharing, that makes sense.

*/
std::array<bst_node_t, kNodesCount + 1> nidx_in_tree_;

static bool IsLeaf(const RegTree& tree, bst_node_t nidx) {
Member


Is there a benefit of doing this C++ overloading rather than the simpler tree.IsLeaf? How much faster are we seeing?

Contributor Author


I did the overload to handle both RegTree and MultiTargetTree cases. Is there a better option?

Member


Use RegTree without extracting the Multi-target tree when populating the buffer, and delegate the dispatching to RegTree::LeftChild(bst_node_t nidx) instead of using the RegTree::Node::LeftChild. There's a check inside the RegTree::LeftChild:

  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
    if (IsMultiTarget()) {
      return this->p_mt_tree_->LeftChild(nidx);
    }
    return (*this)[nidx].LeftChild();
  }

@trivialfis
Member

I'm trying to clean up the CPU predictor. I will update this PR once it is finished.

@trivialfis
Member

I need to fix a perf regression caused by the new ordinal encoder.

@trivialfis
Member

I need to fix a perf regression caused by the new ordinal encoder.

This has been fixed. I will look deeper into this PR.

using DefaultLeftType =
    typename std::conditional_t<any_missing,
                                std::array<uint8_t, kNodesCount>,
                                struct Empty>;
Member


Suggested change:
-                                struct Empty>;
+                                Empty>;

@trivialfis
Member

trivialfis commented Aug 20, 2025

Thank you for expanding the tree layout. In the future (when you can prioritize it), do you think it's possible to create and store the layout inside the RegTree structure as an opt-in method call? My reasoning is as follows:

  • The existing RegTree and the multi-target tree already use a very similar layout, minus the dummy nodes. It might be easier/cleaner to do it there.
  • We can avoid complicating the predictor too much.
  • We can cache the result in the regtree structure to avoid repeated initialization.

You can define a std::unique_ptr<ArrayTree> inside the RegTree, set it to nullptr. Define a method to create the array tree when needed, and reset it back to nullptr if any non-const method is called.
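The lazy, cache-and-invalidate pattern suggested here could look roughly like the following (all type and method names are hypothetical stand-ins; RegTree's real interface differs):

```cpp
#include <memory>

// Hypothetical stand-in for the compact array-based layout.
struct ArrayTree {};

class Tree {
 public:
  // Build the array layout lazily on first use and cache it;
  // subsequent const callers reuse the cached copy.
  const ArrayTree* GetArrayLayout() const {
    if (!array_tree_) {
      array_tree_ = std::make_unique<ArrayTree>(/* built from this tree */);
    }
    return array_tree_.get();
  }

  // Any non-const mutation invalidates the cached layout.
  void ExpandNode(/* ... */) {
    array_tree_.reset();
    // ... actual node expansion ...
  }

 private:
  mutable std::unique_ptr<ArrayTree> array_tree_;  // nullptr until first use
};
```

A real implementation would also have to consider thread safety, since prediction may be called concurrently from multiple threads while the cache is being populated.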
