|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include "common.cuh" |
| 10 | + |
| 11 | +using Tensor = at::Tensor; |
| 12 | + |
| 13 | +namespace fbgemm_gpu { |
| 14 | + |
| 15 | +template <typename index_t> |
| 16 | +__global__ |
| 17 | +__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_lengths_kernel( |
| 18 | + const index_t* __restrict__ lengths, |
| 19 | + const bool* __restrict__ mask, |
| 20 | + index_t* __restrict__ masked_lengths, |
| 21 | + const index_t* __restrict__ input_offsets, |
| 22 | + const index_t batch_size) { |
| 23 | + const index_t batch_idx = blockIdx.x; |
| 24 | + |
| 25 | + if (batch_idx >= batch_size) { |
| 26 | + return; |
| 27 | + } |
| 28 | + |
| 29 | + const index_t input_offset = input_offsets[batch_idx]; |
| 30 | + const index_t input_len = lengths[batch_idx]; |
| 31 | + |
| 32 | + int32_t local_count = 0; |
| 33 | + for (index_t i = threadIdx.x; i < input_len; i += blockDim.x) { |
| 34 | + const index_t input_idx = input_offset + i; |
| 35 | + |
| 36 | + if (mask[input_idx]) { |
| 37 | + local_count++; |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + __shared__ int32_t shared_counts[kMaxThreads]; |
| 42 | + shared_counts[threadIdx.x] = local_count; |
| 43 | + __syncthreads(); |
| 44 | + |
| 45 | + for (auto stride = blockDim.x / 2; stride > 0; stride /= 2) { |
| 46 | + if (threadIdx.x < stride) { |
| 47 | + shared_counts[threadIdx.x] += shared_counts[threadIdx.x + stride]; |
| 48 | + } |
| 49 | + __syncthreads(); |
| 50 | + } |
| 51 | + |
| 52 | + if (threadIdx.x == 0) { |
| 53 | + masked_lengths[batch_idx] = static_cast<index_t>(shared_counts[0]); |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +template <typename index_t, typename scalar_t> |
| 58 | +__global__ |
| 59 | +__launch_bounds__(kMaxThreads) void masked_select_jagged_1d_values_kernel( |
| 60 | + const scalar_t* __restrict__ values, |
| 61 | + const index_t* __restrict__ lengths, |
| 62 | + const bool* __restrict__ mask, |
| 63 | + scalar_t* __restrict__ masked_values, |
| 64 | + const index_t* __restrict__ input_offsets, |
| 65 | + const index_t* __restrict__ output_offsets, |
| 66 | + const index_t batch_size) { |
| 67 | + const index_t batch_idx = blockIdx.x; |
| 68 | + |
| 69 | + if (batch_idx >= batch_size) { |
| 70 | + return; |
| 71 | + } |
| 72 | + |
| 73 | + const index_t input_offset = input_offsets[batch_idx]; |
| 74 | + const index_t output_offset = output_offsets[batch_idx]; |
| 75 | + const index_t input_len = lengths[batch_idx]; |
| 76 | + |
| 77 | + int32_t write_pos = 0; |
| 78 | + |
| 79 | + for (index_t i = 0; i < input_len; i++) { |
| 80 | + const index_t input_idx = input_offset + i; |
| 81 | + |
| 82 | + const bool is_masked = mask[input_idx]; |
| 83 | + |
| 84 | + if (threadIdx.x == 0 && is_masked) { |
| 85 | + const index_t output_idx = output_offset + write_pos; |
| 86 | + |
| 87 | + masked_values[output_idx] = values[input_idx]; |
| 88 | + write_pos++; |
| 89 | + } |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +std::tuple<Tensor, Tensor> masked_select_jagged_1d_cuda( |
| 94 | + const Tensor& values, |
| 95 | + const Tensor& lengths, |
| 96 | + const Tensor& mask, |
| 97 | + const std::optional<bool> check_length) { |
| 98 | + TENSOR_ON_CUDA_GPU(values); |
| 99 | + TENSOR_ON_CUDA_GPU(lengths); |
| 100 | + TENSOR_ON_CUDA_GPU(mask); |
| 101 | + |
| 102 | + TORCH_CHECK(values.dim() == 1); |
| 103 | + TORCH_CHECK(lengths.dim() == 1); |
| 104 | + TORCH_CHECK(mask.dim() == 1); |
| 105 | + |
| 106 | + if (check_length.has_value() && check_length.value()) { |
| 107 | + TORCH_CHECK( |
| 108 | + mask.numel() == values.numel(), |
| 109 | + "mask and values should have the same numel, but got mask numel: ", |
| 110 | + mask.numel(), |
| 111 | + " values numel: ", |
| 112 | + values.numel()); |
| 113 | + } |
| 114 | + |
| 115 | + const auto batch_size = lengths.numel(); |
| 116 | + Tensor masked_lengths = at::empty_like(lengths); |
| 117 | + |
| 118 | + if (batch_size == 0) { |
| 119 | + Tensor masked_values = at::empty({0}, values.options()); |
| 120 | + return {masked_values, masked_lengths}; |
| 121 | + } |
| 122 | + |
| 123 | + Tensor input_offsets = asynchronous_complete_cumsum_gpu(lengths); |
| 124 | + |
| 125 | + TORCH_CHECK( |
| 126 | + input_offsets.numel() == batch_size + 1, |
| 127 | + "input_offsets should have size batch_size+1, got ", |
| 128 | + input_offsets.numel(), |
| 129 | + " expected ", |
| 130 | + batch_size + 1); |
| 131 | + |
| 132 | + Tensor mask_int = mask.to(at::kInt); |
| 133 | + Tensor mask_cumsum = asynchronous_complete_cumsum_gpu(mask_int); |
| 134 | + const int32_t num_outputs = mask_cumsum[-1].item<int32_t>(); |
| 135 | + Tensor masked_values = at::empty({num_outputs}, values.options()); |
| 136 | + |
| 137 | + AT_DISPATCH_INDEX_TYPES( |
| 138 | + lengths.scalar_type(), "masked_select_jagged_1d_lengths", [&] { |
| 139 | + const int num_blocks = batch_size; |
| 140 | + // First pass: compute masked lengths |
| 141 | + FBGEMM_LAUNCH_KERNEL( |
| 142 | + (masked_select_jagged_1d_lengths_kernel<index_t>), |
| 143 | + num_blocks, |
| 144 | + kMaxThreads, |
| 145 | + 0, |
| 146 | + at::cuda::getCurrentCUDAStream(), |
| 147 | + lengths.data_ptr<index_t>(), |
| 148 | + mask.data_ptr<bool>(), |
| 149 | + masked_lengths.data_ptr<index_t>(), |
| 150 | + input_offsets.data_ptr<index_t>(), |
| 151 | + static_cast<index_t>(batch_size)); |
| 152 | + |
| 153 | + Tensor output_offsets = |
| 154 | + asynchronous_complete_cumsum_gpu(masked_lengths); |
| 155 | + |
| 156 | + TORCH_CHECK( |
| 157 | + output_offsets.numel() == batch_size + 1, |
| 158 | + "output_offsets should have size batch_size+1, got ", |
| 159 | + output_offsets.numel(), |
| 160 | + " expected ", |
| 161 | + batch_size + 1); |
| 162 | + |
| 163 | + // Second pass: write masked values |
| 164 | + FBGEMM_DISPATCH_ALL_TYPES( |
| 165 | + values.scalar_type(), "masked_select_jagged_1d_values", [&] { |
| 166 | + FBGEMM_LAUNCH_KERNEL( |
| 167 | + (masked_select_jagged_1d_values_kernel<index_t, scalar_t>), |
| 168 | + num_blocks, |
| 169 | + 1, // Use single thread per block for simplicity |
| 170 | + 0, |
| 171 | + at::cuda::getCurrentCUDAStream(), |
| 172 | + values.data_ptr<scalar_t>(), |
| 173 | + lengths.data_ptr<index_t>(), |
| 174 | + mask.data_ptr<bool>(), |
| 175 | + masked_values.data_ptr<scalar_t>(), |
| 176 | + input_offsets.data_ptr<index_t>(), |
| 177 | + output_offsets.data_ptr<index_t>(), |
| 178 | + static_cast<index_t>(batch_size)); |
| 179 | + }); |
| 180 | + }); |
| 181 | + |
| 182 | + return {masked_values, masked_lengths}; |
| 183 | +} |
| 184 | + |
| 185 | +} // namespace fbgemm_gpu |
| 186 | + |
| 187 | +FBGEMM_OP_DISPATCH( |
| 188 | + CUDA, |
| 189 | + "masked_select_jagged_1d", |
| 190 | + fbgemm_gpu::masked_select_jagged_1d_cuda); |
0 commit comments