diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py
index 51375f0a64..2c8a9c7c4a 100644
--- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py
+++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -495,22 +495,25 @@ def ref(values: torch.Tensor, lengths: torch.Tensor, max_len: int) -> torch.Tens
 @cli.command()
 @click.option("--batch-size", type=int, default=1024)
 @click.option("--max-len", type=int, default=256)
+@click.option("--use-cpu", is_flag=True, default=False)
 def masked_select_jagged_1d(
     batch_size: int,
     max_len: int,
+    use_cpu: bool,
 ) -> None:
-    lengths = torch.randint(2 * max_len, size=(batch_size,))  # Allow for truncation
+    device = "cpu" if use_cpu else "cuda"
+    lengths = torch.randint(2 * max_len, size=(batch_size,), device=device)
     total_lengths = int(lengths.sum().item())
     dtype = torch.long
-    values = torch.randint(2**16, (total_lengths,), dtype=dtype)
-    mask = torch.randint(2, (total_lengths,)) > 0
+    values = torch.randint(2**16, (total_lengths,), dtype=dtype, device=device)
+    mask = torch.randint(2, (total_lengths,), device=device) > 0
 
     def ref(
         values: torch.Tensor, lengths: torch.Tensor, mask: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
         masked_values_ref = values[mask]
         cum_count = torch.cumsum(mask, 0)
-        cum_count = torch.cat((cum_count, torch.tensor([0])))
+        cum_count = torch.cat((cum_count, torch.tensor([0], device=values.device)))
         cum_length = cum_count[torch.cumsum(lengths, 0) - 1]
         cum_length_shift_right = torch.roll(cum_length, 1)
         cum_length_shift_right[0] = 0
@@ -532,8 +535,10 @@ def ref(
 
     bytes = (2 * values.numel() + 2 * lengths.numel() + 2 * masked_values.numel()) * 4
 
-    logging.info(f"reference {time_ref} sec {bytes / time_ref / 1e9} GB/s")
-    logging.info(f"masked_select_jagged_1d {time} sec {bytes / time / 1e9} GB/s")
+    logging.info(f"[{device}] reference {time_ref} sec {bytes / time_ref / 1e9} GB/s")
+    logging.info(
+        f"[{device}] masked_select_jagged_1d {time} sec {bytes / time / 1e9} GB/s"
+    )
 
 
 @cli.command()
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/masked_select_jagged_1d.cu b/fbgemm_gpu/src/jagged_tensor_ops/masked_select_jagged_1d.cu
new file mode 100644
index 0000000000..7b7aea71c8
--- /dev/null
+++ b/fbgemm_gpu/src/jagged_tensor_ops/masked_select_jagged_1d.cu
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "common.cuh"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+template <typename index_t>
+__global__
+__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_lengths_kernel(
+    const index_t* __restrict__ lengths,
+    const bool* __restrict__ mask,
+    index_t* __restrict__ masked_lengths,
+    const index_t* __restrict__ input_offsets,
+    const index_t batch_size) {
+  const index_t batch_idx = blockIdx.x;
+
+  if (batch_idx >= batch_size) {
+    return;
+  }
+
+  const index_t input_offset = input_offsets[batch_idx];
+  const index_t input_len = lengths[batch_idx];
+
+  int32_t local_count = 0;
+  for (index_t i = threadIdx.x; i < input_len; i += blockDim.x) {
+    const index_t input_idx = input_offset + i;
+
+    if (mask[input_idx]) {
+      local_count++;
+    }
+  }
+
+  __shared__ int32_t shared_counts[kMaxThreads];
+  shared_counts[threadIdx.x] = local_count;
+  __syncthreads();
+
+  for (auto stride = blockDim.x / 2; stride > 0; stride /= 2) {
+    if (threadIdx.x < stride) {
+      shared_counts[threadIdx.x] += shared_counts[threadIdx.x + stride];
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x == 0) {
+    masked_lengths[batch_idx] = static_cast<index_t>(shared_counts[0]);
+  }
+}
+
+template <typename scalar_t, typename index_t>
+__global__
+__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_values_kernel(
+    const scalar_t* __restrict__ values,
+    const index_t* __restrict__ lengths,
+    const bool* __restrict__ mask,
+    scalar_t* __restrict__ masked_values,
+    const index_t* __restrict__ input_offsets,
+    const index_t* __restrict__ output_offsets,
+    const index_t batch_size) {
+  const index_t batch_idx = blockIdx.x;
+
+  if (batch_idx >= batch_size) {
+    return;
+  }
+
+  const index_t input_offset = input_offsets[batch_idx];
+  const index_t output_offset = output_offsets[batch_idx];
+  const index_t input_len = lengths[batch_idx];
+
+  int32_t write_pos = 0;
+
+  for (index_t i = 0; i < input_len; i++) {
+    const index_t input_idx = input_offset + i;
+
+    const bool is_masked = mask[input_idx];
+
+    if (threadIdx.x == 0 && is_masked) {
+      const index_t output_idx = output_offset + write_pos;
+
+      masked_values[output_idx] = values[input_idx];
+      write_pos++;
+    }
+  }
+}
+
+std::tuple<Tensor, Tensor> masked_select_jagged_1d_cuda(
+    const Tensor& values,
+    const Tensor& lengths,
+    const Tensor& mask,
+    const std::optional<bool> check_length) {
+  TENSOR_ON_CUDA_GPU(values);
+  TENSOR_ON_CUDA_GPU(lengths);
+  TENSOR_ON_CUDA_GPU(mask);
+
+  TORCH_CHECK(values.dim() == 1);
+  TORCH_CHECK(lengths.dim() == 1);
+  TORCH_CHECK(mask.dim() == 1);
+
+  if (check_length.has_value() && check_length.value()) {
+    TORCH_CHECK(
+        mask.numel() == values.numel(),
+        "mask and values should have the same numel, but got mask numel: ",
+        mask.numel(),
+        " values numel: ",
+        values.numel());
+  }
+
+  const auto batch_size = lengths.numel();
+  Tensor masked_lengths = at::empty_like(lengths);
+
+  if (batch_size == 0) {
+    Tensor masked_values = at::empty({0}, values.options());
+    return {masked_values, masked_lengths};
+  }
+
+  Tensor input_offsets = asynchronous_complete_cumsum_gpu(lengths);
+
+  TORCH_CHECK(
+      input_offsets.numel() == batch_size + 1,
+      "input_offsets should have size batch_size+1, got ",
+      input_offsets.numel(),
+      " expected ",
+      batch_size + 1);
+
+  Tensor mask_int = mask.to(at::kInt);
+  Tensor mask_cumsum = asynchronous_complete_cumsum_gpu(mask_int);
+  const int32_t num_outputs = mask_cumsum[-1].item<int32_t>();
+  Tensor masked_values = at::empty({num_outputs}, values.options());
+
+  AT_DISPATCH_INDEX_TYPES(
+      lengths.scalar_type(), "masked_select_jagged_1d_lengths", [&] {
+        const int num_blocks = batch_size;
+        // First pass: compute masked lengths
+        FBGEMM_LAUNCH_KERNEL(
+            (masked_select_jagged_1d_lengths_kernel<index_t>),
+            num_blocks,
+            kMaxThreads,
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            lengths.data_ptr<index_t>(),
+            mask.data_ptr<bool>(),
+            masked_lengths.data_ptr<index_t>(),
+            input_offsets.data_ptr<index_t>(),
+            static_cast<index_t>(batch_size));
+
+        Tensor output_offsets =
+            asynchronous_complete_cumsum_gpu(masked_lengths);
+
+        TORCH_CHECK(
+            output_offsets.numel() == batch_size + 1,
+            "output_offsets should have size batch_size+1, got ",
+            output_offsets.numel(),
+            " expected ",
+            batch_size + 1);
+
+        // Second pass: write masked values
+        FBGEMM_DISPATCH_ALL_TYPES(
+            values.scalar_type(), "masked_select_jagged_1d_values", [&] {
+              FBGEMM_LAUNCH_KERNEL(
+                  (masked_select_jagged_1d_values_kernel<scalar_t, index_t>),
+                  num_blocks,
+                  1, // Use single thread per block for simplicity
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  values.data_ptr<scalar_t>(),
+                  lengths.data_ptr<index_t>(),
+                  mask.data_ptr<bool>(),
+                  masked_values.data_ptr<scalar_t>(),
+                  input_offsets.data_ptr<index_t>(),
+                  output_offsets.data_ptr<index_t>(),
+                  static_cast<index_t>(batch_size));
+            });
+      });
+
+  return {masked_values, masked_lengths};
+}
+
+} // namespace fbgemm_gpu
+
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "masked_select_jagged_1d",
+    fbgemm_gpu::masked_select_jagged_1d_cuda);
diff --git a/fbgemm_gpu/test/jagged/misc_ops_test.py b/fbgemm_gpu/test/jagged/misc_ops_test.py
index 4df88bb285..4578fdd6f7 100644
--- a/fbgemm_gpu/test/jagged/misc_ops_test.py
+++ b/fbgemm_gpu/test/jagged/misc_ops_test.py
@@ -95,7 +95,7 @@ def test_jagged_1d_to_truncated_values(
         index_dtype=st.sampled_from([torch.int, torch.long]),
         jagged_tensor_dtype=st.sampled_from([torch.int, torch.long]),
         empty_lengths=st.booleans(),
-        use_cpu=st.just(True),
+        use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True),
     )
     @settings(max_examples=20, deadline=None)
     def test_masked_select_jagged_1d(
@@ -118,7 +118,7 @@ def test_masked_select_jagged_1d(
             dtype=index_dtype,
             device=device,
         )
-        lengths[batch_size // 2] = 0  # test a corner case
+        lengths[batch_size // 2] = 0
         n = int(lengths.sum().item())
         values = torch.randint(
             2**16,
@@ -126,7 +126,7 @@ def test_masked_select_jagged_1d(
             dtype=jagged_tensor_dtype,
             device=device,
         )
-        mask = torch.randint(2, (n,)) > 0
+        mask = torch.randint(2, (n,), device=device) > 0
 
         masked_values, masked_lengths = torch.ops.fbgemm.masked_select_jagged_1d(
             values,
@@ -136,7 +136,7 @@ def test_masked_select_jagged_1d(
 
         masked_values_ref = values[mask]
         cum_count = torch.cumsum(mask, 0)
-        cum_count = torch.cat((cum_count, torch.tensor([0])))
+        cum_count = torch.cat((cum_count, torch.tensor([0], device=device)))
         cum_length = cum_count[torch.cumsum(lengths, 0) - 1]
         cum_length_shift_right = torch.roll(cum_length, 1)
         cum_length_shift_right[0] = 0
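
A minimal Python usage sketch of the op exercised by this patch (not part of the diff; it mirrors the benchmark/test setup above and assumes a fbgemm_gpu build that registers the CUDA dispatch added here):

    import torch
    import fbgemm_gpu  # noqa: F401  (importing registers torch.ops.fbgemm.*)

    # Small, hypothetical example following the benchmark above.
    device = "cuda"
    lengths = torch.randint(10, size=(4,), device=device)
    values = torch.randint(
        2**16, (int(lengths.sum().item()),), dtype=torch.long, device=device
    )
    mask = torch.randint(2, (values.numel(),), device=device) > 0
    masked_values, masked_lengths = torch.ops.fbgemm.masked_select_jagged_1d(
        values, lengths, mask
    )
    # masked_values keeps the entries of values where mask is True (values[mask]);
    # masked_lengths holds the per-row counts of surviving elements, matching the
    # cumsum-based reference used in the benchmark and test.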