
Commit 6608ec9

Max Kaplan authored and facebook-github-bot committed
Add CUDA implementation for masked_select_jagged_1d() (#5179)
Summary:
X-link: facebookresearch/FBGEMM#2173

Implemented a CUDA kernel for the masked_select_jagged_1d operation to enable GPU acceleration for jagged tensor masking operations.

Differential Revision: D88175195
1 parent 7cbaff6 commit 6608ec9
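
For context, a minimal usage sketch of the operator on GPU tensors (illustrative only, not part of this commit; it assumes an fbgemm_gpu build that registers the torch.ops.fbgemm operators):

import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

device = "cuda"
lengths = torch.tensor([2, 0, 3], dtype=torch.long, device=device)
values = torch.arange(int(lengths.sum().item()), dtype=torch.long, device=device)
mask = torch.tensor([True, False, True, True, False], device=device)

# With this commit the op also accepts CUDA tensors; the test previously
# exercised the CPU path only.
masked_values, masked_lengths = torch.ops.fbgemm.masked_select_jagged_1d(
    values, lengths, mask
)
# masked_values == values[mask]; masked_lengths holds the per-segment counts
# of surviving elements, here tensor([1, 0, 2]).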

File tree

3 files changed: +205, -10 lines


fbgemm_gpu/bench/jagged_tensor_benchmark.py

Lines changed: 11 additions & 6 deletions
@@ -495,22 +495,25 @@ def ref(values: torch.Tensor, lengths: torch.Tensor, max_len: int) -> torch.Tensor:
 @cli.command()
 @click.option("--batch-size", type=int, default=1024)
 @click.option("--max-len", type=int, default=256)
+@click.option("--use-cpu", is_flag=True, default=False)
 def masked_select_jagged_1d(
     batch_size: int,
     max_len: int,
+    use_cpu: bool,
 ) -> None:
-    lengths = torch.randint(2 * max_len, size=(batch_size,))  # Allow for truncation
+    device = "cpu" if use_cpu else "cuda"
+    lengths = torch.randint(2 * max_len, size=(batch_size,), device=device)
     total_lengths = int(lengths.sum().item())
     dtype = torch.long
-    values = torch.randint(2**16, (total_lengths,), dtype=dtype)
-    mask = torch.randint(2, (total_lengths,)) > 0
+    values = torch.randint(2**16, (total_lengths,), dtype=dtype, device=device)
+    mask = torch.randint(2, (total_lengths,), device=device) > 0

     def ref(
         values: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         masked_values_ref = values[mask]
         cum_count = torch.cumsum(mask, 0)
-        cum_count = torch.cat((cum_count, torch.tensor([0])))
+        cum_count = torch.cat((cum_count, torch.tensor([0], device=values.device)))
         cum_length = cum_count[torch.cumsum(lengths, 0) - 1]
         cum_length_shift_right = torch.roll(cum_length, 1)
         cum_length_shift_right[0] = 0
@@ -532,8 +535,10 @@ def ref(

     bytes = (2 * values.numel() + 2 * lengths.numel() + 2 * masked_values.numel()) * 4

-    logging.info(f"reference {time_ref} sec {bytes / time_ref / 1e9} GB/s")
-    logging.info(f"masked_select_jagged_1d {time} sec {bytes / time / 1e9} GB/s")
+    logging.info(f"[{device}] reference {time_ref} sec {bytes / time_ref / 1e9} GB/s")
+    logging.info(
+        f"[{device}] masked_select_jagged_1d {time} sec {bytes / time / 1e9} GB/s"
+    )


 @cli.command()
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "common.cuh"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+template <typename index_t>
+__global__
+__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_lengths_kernel(
+    const index_t* __restrict__ lengths,
+    const bool* __restrict__ mask,
+    index_t* __restrict__ masked_lengths,
+    const index_t* __restrict__ input_offsets,
+    const index_t batch_size) {
+  const index_t batch_idx = blockIdx.x;
+
+  if (batch_idx >= batch_size) {
+    return;
+  }
+
+  const index_t input_offset = input_offsets[batch_idx];
+  const index_t input_len = lengths[batch_idx];
+
+  int32_t local_count = 0;
+  for (index_t i = threadIdx.x; i < input_len; i += blockDim.x) {
+    const index_t input_idx = input_offset + i;
+
+    if (mask[input_idx]) {
+      local_count++;
+    }
+  }
+
+  __shared__ int32_t shared_counts[kMaxThreads];
+  shared_counts[threadIdx.x] = local_count;
+  __syncthreads();
+
+  for (auto stride = blockDim.x / 2; stride > 0; stride /= 2) {
+    if (threadIdx.x < stride) {
+      shared_counts[threadIdx.x] += shared_counts[threadIdx.x + stride];
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x == 0) {
+    masked_lengths[batch_idx] = static_cast<index_t>(shared_counts[0]);
+  }
+}
+
+template <typename index_t, typename scalar_t>
+__global__
+__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_values_kernel(
+    const scalar_t* __restrict__ values,
+    const index_t* __restrict__ lengths,
+    const bool* __restrict__ mask,
+    scalar_t* __restrict__ masked_values,
+    const index_t* __restrict__ input_offsets,
+    const index_t* __restrict__ output_offsets,
+    const index_t batch_size) {
+  const index_t batch_idx = blockIdx.x;
+
+  if (batch_idx >= batch_size) {
+    return;
+  }
+
+  const index_t input_offset = input_offsets[batch_idx];
+  const index_t output_offset = output_offsets[batch_idx];
+  const index_t input_len = lengths[batch_idx];
+
+  int32_t write_pos = 0;
+
+  for (index_t i = 0; i < input_len; i++) {
+    const index_t input_idx = input_offset + i;
+
+    const bool is_masked = mask[input_idx];
+
+    if (threadIdx.x == 0 && is_masked) {
+      const index_t output_idx = output_offset + write_pos;
+
+      masked_values[output_idx] = values[input_idx];
+      write_pos++;
+    }
+  }
+}
+
+std::tuple<Tensor, Tensor> masked_select_jagged_1d_cuda(
+    const Tensor& values,
+    const Tensor& lengths,
+    const Tensor& mask,
+    const std::optional<bool> check_length) {
+  TENSOR_ON_CUDA_GPU(values);
+  TENSOR_ON_CUDA_GPU(lengths);
+  TENSOR_ON_CUDA_GPU(mask);
+
+  TORCH_CHECK(values.dim() == 1);
+  TORCH_CHECK(lengths.dim() == 1);
+  TORCH_CHECK(mask.dim() == 1);
+
+  if (check_length.has_value() && check_length.value()) {
+    TORCH_CHECK(
+        mask.numel() == values.numel(),
+        "mask and values should have the same numel, but got mask numel: ",
+        mask.numel(),
+        " values numel: ",
+        values.numel());
+  }
+
+  const auto batch_size = lengths.numel();
+  Tensor masked_lengths = at::empty_like(lengths);
+
+  if (batch_size == 0) {
+    Tensor masked_values = at::empty({0}, values.options());
+    return {masked_values, masked_lengths};
+  }
+
+  Tensor input_offsets = asynchronous_complete_cumsum_gpu(lengths);
+
+  TORCH_CHECK(
+      input_offsets.numel() == batch_size + 1,
+      "input_offsets should have size batch_size+1, got ",
+      input_offsets.numel(),
+      " expected ",
+      batch_size + 1);
+
+  Tensor mask_int = mask.to(at::kInt);
+  Tensor mask_cumsum = asynchronous_complete_cumsum_gpu(mask_int);
+  const int32_t num_outputs = mask_cumsum[-1].item<int32_t>();
+  Tensor masked_values = at::empty({num_outputs}, values.options());
+
+  AT_DISPATCH_INDEX_TYPES(
+      lengths.scalar_type(), "masked_select_jagged_1d_lengths", [&] {
+        const int num_blocks = batch_size;
+        // First pass: compute masked lengths
+        FBGEMM_LAUNCH_KERNEL(
+            (masked_select_jagged_1d_lengths_kernel<index_t>),
+            num_blocks,
+            kMaxThreads,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            lengths.data_ptr<index_t>(),
+            mask.data_ptr<bool>(),
+            masked_lengths.data_ptr<index_t>(),
+            input_offsets.data_ptr<index_t>(),
+            static_cast<index_t>(batch_size));
+
+        Tensor output_offsets =
+            asynchronous_complete_cumsum_gpu(masked_lengths);
+
+        TORCH_CHECK(
+            output_offsets.numel() == batch_size + 1,
+            "output_offsets should have size batch_size+1, got ",
+            output_offsets.numel(),
+            " expected ",
+            batch_size + 1);
+
+        // Second pass: write masked values
+        FBGEMM_DISPATCH_ALL_TYPES(
+            values.scalar_type(), "masked_select_jagged_1d_values", [&] {
+              FBGEMM_LAUNCH_KERNEL(
+                  (masked_select_jagged_1d_values_kernel<index_t, scalar_t>),
+                  num_blocks,
+                  1, // Use single thread per block for simplicity
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  values.data_ptr<scalar_t>(),
+                  lengths.data_ptr<index_t>(),
+                  mask.data_ptr<bool>(),
+                  masked_values.data_ptr<scalar_t>(),
+                  input_offsets.data_ptr<index_t>(),
+                  output_offsets.data_ptr<index_t>(),
+                  static_cast<index_t>(batch_size));
+            });
+      });
+
+  return {masked_values, masked_lengths};
+}
+
+} // namespace fbgemm_gpu
+
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "masked_select_jagged_1d",
+    fbgemm_gpu::masked_select_jagged_1d_cuda);
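
A brief orientation on the kernels above: the first pass launches one block per segment and counts surviving elements with a shared-memory reduction; a complete cumsum (leading zero) over those counts then gives the output offsets, and the second pass compacts the values segment by segment, deliberately using a single thread per block. A rough CPU-side sketch of the same two-pass semantics, for illustration only (the helper name below is made up):

import torch

def two_pass_masked_select_jagged_1d(values, lengths, mask):
    # Segment boundaries in the flat values/mask tensors.
    offsets = torch.cat(
        (torch.zeros(1, dtype=lengths.dtype), torch.cumsum(lengths, 0))
    ).tolist()
    # Pass 1: per-segment count of kept elements (the lengths kernel).
    masked_lengths = torch.tensor(
        [int(mask[offsets[b] : offsets[b + 1]].sum()) for b in range(lengths.numel())],
        dtype=lengths.dtype,
    )
    # Pass 2: compact the kept values (the values kernel, one segment per block).
    masked_values = values[mask]
    return masked_values, masked_lengths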

fbgemm_gpu/test/jagged/misc_ops_test.py

Lines changed: 4 additions & 4 deletions
@@ -95,7 +95,7 @@ def test_jagged_1d_to_truncated_values(
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from([torch.int, torch.long]),
         empty_lengths=st.booleans(),
-        use_cpu=st.just(True),
+        use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True),
     )
     @settings(max_examples=20, deadline=None)
     def test_masked_select_jagged_1d(
@@ -118,15 +118,15 @@ def test_masked_select_jagged_1d(
             dtype=index_dtype,
             device=device,
         )
-        lengths[batch_size // 2] = 0  # test a corner case
+        lengths[batch_size // 2] = 0
         n = int(lengths.sum().item())
         values = torch.randint(
             2**16,
             (n,),
             dtype=jagged_tensor_dtype,
             device=device,
         )
-        mask = torch.randint(2, (n,)) > 0
+        mask = torch.randint(2, (n,), device=device) > 0

         masked_values, masked_lengths = torch.ops.fbgemm.masked_select_jagged_1d(
             values,
@@ -136,7 +136,7 @@ def test_masked_select_jagged_1d(

         masked_values_ref = values[mask]
         cum_count = torch.cumsum(mask, 0)
-        cum_count = torch.cat((cum_count, torch.tensor([0])))
+        cum_count = torch.cat((cum_count, torch.tensor([0], device=device)))
         cum_length = cum_count[torch.cumsum(lengths, 0) - 1]
         cum_length_shift_right = torch.roll(cum_length, 1)
         cum_length_shift_right[0] = 0
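
On the reference check used in this test (and in the benchmark's ref()): cum_count[i] is the number of kept elements among the first i+1 values, so reading it at each segment's last flat index and differencing adjacent segments yields the per-segment masked lengths; the appended 0 is there so that a leading empty segment, whose boundary index is -1, picks up 0. A small worked example (illustrative only):

import torch

lengths = torch.tensor([2, 0, 3])
mask = torch.tensor([True, False, True, True, False])

cum_count = torch.cumsum(mask, 0)                          # tensor([1, 1, 2, 3, 3])
cum_count = torch.cat((cum_count, torch.tensor([0])))      # index -1 now maps to 0
cum_length = cum_count[torch.cumsum(lengths, 0) - 1]       # indices [1, 1, 4] -> [1, 1, 3]
cum_length_shift_right = torch.roll(cum_length, 1)
cum_length_shift_right[0] = 0                              # [0, 1, 1]
masked_lengths_ref = cum_length - cum_length_shift_right   # tensor([1, 0, 2])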
