From 7676946aa4ee4da988c97847951b263c5766c525 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Mon, 5 Jan 2026 11:03:55 -0800
Subject: [PATCH] [slimtensor] Add common_shims_slim with basic property getters

Add SlimTensor-based implementations of the basic property getter AOTI shim
functions:

1. `aoti_torch_get_data_ptr()` - Returns a pointer to the tensor data
2. `aoti_torch_get_sizes()` - Returns a pointer to the sizes array (SlimTensor stores sizes as int64_t directly)
3. `aoti_torch_get_strides()` - Returns a pointer to the strides array (SlimTensor stores strides as int64_t directly)
4. `aoti_torch_get_dtype()` - Returns the scalar type as int32_t
5. `aoti_torch_get_dim()` - Returns the number of dimensions

Key design:
- Add a new common_shims_slim.h so the new API can be developed without impacting the current pipeline. common_shims_slim.{h/cpp} will replace the current common_shims.{h/cpp} once everything is in place.
- Use `#ifdef CUDA_AVAILABLE` conditional compilation to separate the CUDA backend implementation from the MPS backend, since SlimTensor does not have MPS support yet. The branch will be removed once SlimTensor supports MPS.
- Refactored into a header-only library so the caller's preprocessor flags determine which tensor type is used. This lets a single library serve both the CUDA backend (SlimTensor) and the MPS backend (ETensor).
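
A minimal caller-side sketch of the header-only design (the include path mirrors the file added in this diff; the helper name is made up, and whether `CUDA_AVAILABLE` is defined comes from the caller's build flags, e.g. `-DCUDA_AVAILABLE=1`):

```cpp
// Sketch only: when the build defines CUDA_AVAILABLE, Tensor below is
// SlimTensor; otherwise it is ETensor. Same source, two backends.
#include <executorch/backends/aoti/common_shims_slim.h>

using executorch::runtime::Error;
using namespace executorch::backends::aoti;

// Reads basic properties through the new shim getters.
Error read_tensor_metadata(Tensor* t) {
  int64_t dim = 0;
  int64_t* sizes = nullptr;
  int32_t dtype = -1;
  if (aoti_torch_get_dim(t, &dim) != Error::Ok ||
      aoti_torch_get_sizes(t, &sizes) != Error::Ok ||
      aoti_torch_get_dtype(t, &dtype) != Error::Ok) {
    return Error::InvalidArgument;
  }
  // `sizes` points at `dim` int64_t entries: owned by the tensor itself on
  // the SlimTensor path, or by the shim's metadata cache on the ETensor path.
  return Error::Ok;
}
```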

Differential Revision: [D90126254](https://our.internmc.facebook.com/intern/diff/D90126254/)

[ghstack-poisoned]
---
 backends/aoti/common_shims_slim.h             | 215 ++++++++
 backends/aoti/targets.bzl                     |  18 +
 backends/aoti/tests/TARGETS                   |  25 +
 .../aoti/tests/test_common_shims_slim.cpp     | 460 ++++++++++++++++++
 4 files changed, 718 insertions(+)
 create mode 100644 backends/aoti/common_shims_slim.h
 create mode 100644 backends/aoti/tests/test_common_shims_slim.cpp

diff --git a/backends/aoti/common_shims_slim.h b/backends/aoti/common_shims_slim.h
new file mode 100644
index 00000000000..ad820e69739
--- /dev/null
+++ b/backends/aoti/common_shims_slim.h
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/export.h>
+#include <executorch/runtime/core/error.h>
+#include <algorithm>
+#include <cstdint>
+#include <unordered_map>
+#include <vector>
+
+// Uses conditional compilation to separate the implementation between
+// CUDA backend (SlimTensor) and other backends like MPS (ETensor).
+// The caller determines which path is used by defining CUDA_AVAILABLE.
+#ifdef CUDA_AVAILABLE
+#include <executorch/backends/aoti/slim/core/SlimTensor.h>
+#else
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#endif
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+// Common using declarations for ExecuTorch types
+using executorch::runtime::Error;
+
+// ============================================================
+// Tensor Type Definition - branched based on CUDA_AVAILABLE
+// ============================================================
+#ifdef CUDA_AVAILABLE
+using Tensor = executorch::backends::aoti::slim::SlimTensor;
+#else
+using Tensor = executorch::runtime::etensor::Tensor;
+#endif
+
+// Common AOTI type aliases
+using AOTIRuntimeError = Error;
+using AOTITorchError = Error;
+
+#ifndef CUDA_AVAILABLE
+namespace internal {
+// Global storage for tensor metadata (ETensor path only)
+// SlimTensor stores sizes/strides directly in int64_t[] - no caching needed
+inline std::unordered_map<Tensor*, std::vector<int64_t>>& tensor_to_sizes() {
+  static std::unordered_map<Tensor*, std::vector<int64_t>> instance;
+  return instance;
+}
+inline std::unordered_map<Tensor*, std::vector<int64_t>>& tensor_to_strides() {
+  static std::unordered_map<Tensor*, std::vector<int64_t>> instance;
+  return instance;
+}
+} // namespace internal
+#endif
+
+// ============================================================
+// Basic Property Getters - Inline implementations
+// ============================================================
+
+inline AOTITorchError aoti_torch_get_data_ptr(
+    Tensor* tensor,
+    void** ret_data_ptr) {
+  if (tensor == nullptr) {
+    return Error::InvalidArgument;
+  }
+  if (ret_data_ptr == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+#ifdef CUDA_AVAILABLE
+  *ret_data_ptr = tensor->data_ptr();
+#else
+  *ret_data_ptr = tensor->mutable_data_ptr();
+#endif
+  return Error::Ok;
+}
+
+inline AOTITorchError aoti_torch_get_sizes(
+    Tensor* tensor,
+    int64_t** ret_sizes) {
+  if (tensor == nullptr) {
+    return Error::InvalidArgument;
+  }
+  if (ret_sizes == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+#ifdef CUDA_AVAILABLE
+  // SlimTensor stores sizes directly in int64_t[] - no caching needed
+  *ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
+#else
+  auto it = internal::tensor_to_sizes().find(tensor);
+  bool needs_update = false;
+
+  if (it == internal::tensor_to_sizes().end()) {
+    needs_update = true;
+  } else {
+    // Validate cached metadata matches current tensor state
+    auto tensor_sizes = tensor->sizes();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_sizes.begin(),
+        tensor_sizes.end());
+  }
+
+  if (needs_update) {
+    std::vector<int64_t> sizes(tensor->dim());
+    auto tensor_sizes = tensor->sizes();
+    for (int i = 0; i < tensor->dim(); i++) {
+      sizes[i] = tensor_sizes[i];
+    }
+    it = internal::tensor_to_sizes()
+             .insert_or_assign(tensor, std::move(sizes))
+             .first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors
+  if (it->second.empty()) {
+    static int64_t empty_sizes_placeholder = 0;
+    *ret_sizes = &empty_sizes_placeholder;
+  } else {
+    *ret_sizes = it->second.data();
+  }
+#endif
+  return Error::Ok;
+}
+
+inline AOTITorchError aoti_torch_get_strides(
+    Tensor* tensor,
+    int64_t** ret_strides) {
+  if (tensor == nullptr) {
+    return Error::InvalidArgument;
+  }
+  if (ret_strides == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+#ifdef CUDA_AVAILABLE
+  // SlimTensor stores strides directly in int64_t[] - no caching needed
+  *ret_strides = const_cast<int64_t*>(tensor->strides().data());
+#else
+  auto it = internal::tensor_to_strides().find(tensor);
+  bool needs_update = false;
+
+  if (it == internal::tensor_to_strides().end()) {
+    needs_update = true;
+  } else {
+    // Validate cached metadata matches current tensor state
+    auto tensor_strides = tensor->strides();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_strides.begin(),
+        tensor_strides.end());
+  }
+
+  if (needs_update) {
+    std::vector<int64_t> strides(tensor->dim());
+    auto tensor_strides = tensor->strides();
+    for (int i = 0; i < tensor->dim(); i++) {
+      strides[i] = tensor_strides[i];
+    }
+    it = internal::tensor_to_strides()
+             .insert_or_assign(tensor, std::move(strides))
+             .first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors
+  if (it->second.empty()) {
+    static int64_t empty_strides_placeholder = 0;
+    *ret_strides = &empty_strides_placeholder;
+  } else {
+    *ret_strides = it->second.data();
+  }
+#endif
+  return Error::Ok;
+}
+
+inline AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
+  if (tensor == nullptr) {
+    return Error::InvalidArgument;
+  }
+  if (ret_dtype == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+#ifdef CUDA_AVAILABLE
+  *ret_dtype = static_cast<int32_t>(tensor->dtype());
+#else
+  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());
+#endif
+  return Error::Ok;
+}
+
+inline AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
+  if (tensor == nullptr) {
+    return Error::InvalidArgument;
+  }
+  if (ret_dim == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+  *ret_dim = static_cast<int64_t>(tensor->dim());
+  return Error::Ok;
+}
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl
index 327bef8cc53..4f493437c2e 100644
--- a/backends/aoti/targets.bzl
+++ b/backends/aoti/targets.bzl
@@ -86,3 +86,21 @@ def define_common_targets():
             ":delegate_handle",
         ],
     )
+
+    # SlimTensor-based common shims (header-only library)
+    # The caller determines which tensor type is used by defining CUDA_AVAILABLE.
+    # - With CUDA_AVAILABLE=1: Uses SlimTensor
+    # - Without CUDA_AVAILABLE: Uses ETensor
+    runtime.cxx_library(
+        name = "common_shims_slim",
+        headers = [
+            "common_shims_slim.h",
+            "export.h",
+        ],
+        visibility = ["@EXECUTORCH_CLIENTS"],
+        deps = [
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/backends/aoti/slim/core:slimtensor",
+        ],
+    )
diff --git a/backends/aoti/tests/TARGETS b/backends/aoti/tests/TARGETS
index 8daa8abd4d7..d92e0e32a1f 100644
--- a/backends/aoti/tests/TARGETS
+++ b/backends/aoti/tests/TARGETS
@@ -1,4 +1,5 @@
 load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
+load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
 
 oncall("executorch")
 
@@ -20,3 +21,27 @@ cpp_unittest(
         "//executorch/extension/tensor:tensor",
     ],
 )
+
+cpp_unittest(
+    name = "test_common_shims_slim",
+    srcs = [
+        "test_common_shims_slim.cpp",
+    ],
+    deps = [
+        "//executorch/backends/aoti:common_shims_slim",
+        "//executorch/backends/aoti/slim/core:slimtensor",
+        "//executorch/backends/aoti/slim/factory:empty",
+        "//executorch/runtime/core:core",
+        "//executorch/runtime/platform:platform",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+    preprocessor_flags = [
+        "-DCUDA_AVAILABLE=1",
+    ],
+    keep_gpu_sections = True,
+    remote_execution = re_test_utils.remote_execution(
+        platform = "gpu-remote-execution",
+    ),
+)
diff --git a/backends/aoti/tests/test_common_shims_slim.cpp b/backends/aoti/tests/test_common_shims_slim.cpp
new file mode 100644
index 00000000000..2e4bfa63286
--- /dev/null
+++ b/backends/aoti/tests/test_common_shims_slim.cpp
@@ -0,0 +1,460 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <vector>
+
+#include <executorch/backends/aoti/common_shims_slim.h>
+#include <executorch/backends/aoti/slim/core/SlimTensor.h>
+#include <executorch/backends/aoti/slim/factory/empty.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+#ifdef CUDA_AVAILABLE
+#include <cuda_runtime.h>
+#endif
+
+using namespace executorch::backends::aoti;
+using executorch::runtime::Error;
+
+namespace slim_c10 = executorch::backends::aoti::slim::c10;
+namespace slim = executorch::backends::aoti::slim;
+
+namespace {
+
+#ifdef CUDA_AVAILABLE
+bool isCudaAvailable() {
+  int device_count = 0;
+  cudaError_t err = cudaGetDeviceCount(&device_count);
+  return (err == cudaSuccess && device_count > 0);
+}
+#endif
+
+// Helper to calculate contiguous strides from sizes
+std::vector<int64_t> calculateContiguousStrides(
+    const std::vector<int64_t>& sizes) {
+  std::vector<int64_t> strides(sizes.size());
+  if (sizes.empty()) {
+    return strides;
+  }
+  strides[sizes.size() - 1] = 1;
+  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
+    strides[i] = strides[i + 1] * sizes[i + 1];
+  }
+  return strides;
+}
+
+} // namespace
+
+// Test fixture for common_shims_slim tests
+class CommonShimsSlimTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    et_pal_init();
+  }
+
+  void TearDown() override {
+    // Cleanup tracked tensors
+    for (Tensor* t : tensors_) {
+      delete t;
+    }
+    tensors_.clear();
+  }
+
+  void trackTensor(Tensor* t) {
+    if (t != nullptr) {
+      tensors_.push_back(t);
+    }
+  }
+
+  Tensor* createTestTensor(
+      const std::vector<int64_t>& sizes,
+      slim_c10::DeviceType device_type) {
+    std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+    slim_c10::Device device(device_type, 0);
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Float,
+        device));
+    trackTensor(tensor);
+    return tensor;
+  }
+
+ private:
+  std::vector<Tensor*> tensors_;
+};
+
+// ============================================================================
+// Common test body implementations - parameterized by device type
+// ============================================================================
+
+void runGetDataPtrTest(slim_c10::DeviceType device_type) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+  slim_c10::Device device(device_type, 0);
+
+  Tensor* tensor = new Tensor(slim::empty_strided(
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      slim_c10::ScalarType::Float,
+      device));
+
+  void* data_ptr = nullptr;
+  AOTITorchError error = aoti_torch_get_data_ptr(tensor, &data_ptr);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(data_ptr, nullptr);
+
+  // Verify the returned pointer matches tensor's data_ptr
+  EXPECT_EQ(data_ptr, tensor->data_ptr());
+
+  delete tensor;
+}
+
+void runGetSizesTest(slim_c10::DeviceType device_type) {
+  std::vector<int64_t> sizes = {2, 3, 4};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+  slim_c10::Device device(device_type, 0);
+
+  Tensor* tensor = new Tensor(slim::empty_strided(
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      slim_c10::ScalarType::Float,
+      device));
+
+  int64_t* ret_sizes = nullptr;
+  AOTITorchError error = aoti_torch_get_sizes(tensor, &ret_sizes);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(ret_sizes, nullptr);
+
+  // Verify sizes match
+  EXPECT_EQ(ret_sizes[0], 2);
+  EXPECT_EQ(ret_sizes[1], 3);
+  EXPECT_EQ(ret_sizes[2], 4);
+
+  delete tensor;
+}
+
+void runGetStridesTest(slim_c10::DeviceType device_type) {
+  std::vector<int64_t> sizes = {2, 3, 4};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+  slim_c10::Device device(device_type, 0);
+
+  Tensor* tensor = new Tensor(slim::empty_strided(
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      slim_c10::ScalarType::Float,
+      device));
+
+  int64_t* ret_strides = nullptr;
+  AOTITorchError error = aoti_torch_get_strides(tensor, &ret_strides);
+
+  EXPECT_EQ(error, Error::Ok);
+  EXPECT_NE(ret_strides, nullptr);
+
+  // Verify strides match: [12, 4, 1] for contiguous [2, 3, 4]
+  EXPECT_EQ(ret_strides[0], 12);
+  EXPECT_EQ(ret_strides[1], 4);
+  EXPECT_EQ(ret_strides[2], 1);
+
+  delete tensor;
+}
+
+void runGetDtypeTest(slim_c10::DeviceType device_type) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+  slim_c10::Device device(device_type, 0);
+
+  // Test Float32
+  {
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Float,
+        device));
+
+    int32_t ret_dtype = -1;
+    AOTITorchError error = aoti_torch_get_dtype(tensor, &ret_dtype);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dtype, static_cast<int32_t>(slim_c10::ScalarType::Float));
+
+    delete tensor;
+  }
+
+  // Test Int64
+  {
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Long,
+        device));
+
+    int32_t ret_dtype = -1;
+    AOTITorchError error = aoti_torch_get_dtype(tensor, &ret_dtype);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dtype, static_cast<int32_t>(slim_c10::ScalarType::Long));
+
+    delete tensor;
+  }
+
+  // Test BFloat16
+  {
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::BFloat16,
+        device));
+
+    int32_t ret_dtype = -1;
+    AOTITorchError error = aoti_torch_get_dtype(tensor, &ret_dtype);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dtype, static_cast<int32_t>(slim_c10::ScalarType::BFloat16));
+
+    delete tensor;
+  }
+}
+
+void runGetDimTest(slim_c10::DeviceType device_type) {
+  slim_c10::Device device(device_type, 0);
+
+  // Test 0D tensor (scalar)
+  {
+    std::vector<int64_t> sizes = {};
+    std::vector<int64_t> strides = {};
+
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Float,
+        device));
+
+    int64_t ret_dim = -1;
+    AOTITorchError error = aoti_torch_get_dim(tensor, &ret_dim);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dim, 0);
+
+    delete tensor;
+  }
+
+  // Test 1D tensor
+  {
+    std::vector<int64_t> sizes = {5};
+    std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Float,
+        device));
+
+    int64_t ret_dim = -1;
+    AOTITorchError error = aoti_torch_get_dim(tensor, &ret_dim);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dim, 1);
+
+    delete tensor;
+  }
+
+  // Test 3D tensor
+  {
+    std::vector<int64_t> sizes = {2, 3, 4};
+    std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+    Tensor* tensor = new Tensor(slim::empty_strided(
+        slim::makeArrayRef(sizes),
+        slim::makeArrayRef(strides),
+        slim_c10::ScalarType::Float,
+        device));
+
+    int64_t ret_dim = -1;
+    AOTITorchError error = aoti_torch_get_dim(tensor, &ret_dim);
+
+    EXPECT_EQ(error, Error::Ok);
+    EXPECT_EQ(ret_dim, 3);
+
+    delete tensor;
+  }
+}
+
+// ============================================================================
+// CPU Tests
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, GetDataPtr_CPU) {
+  runGetDataPtrTest(slim_c10::DeviceType::CPU);
+}
+
+TEST_F(CommonShimsSlimTest, GetSizes_CPU) {
+  runGetSizesTest(slim_c10::DeviceType::CPU);
+}
+
+TEST_F(CommonShimsSlimTest, GetStrides_CPU) {
+  runGetStridesTest(slim_c10::DeviceType::CPU);
+}
+
+TEST_F(CommonShimsSlimTest, GetDtype_CPU) {
+  runGetDtypeTest(slim_c10::DeviceType::CPU);
+}
+
+TEST_F(CommonShimsSlimTest, GetDim_CPU) {
+  runGetDimTest(slim_c10::DeviceType::CPU);
+}
+
+// ============================================================================
+// CUDA Tests
+// ============================================================================
+
+#ifdef CUDA_AVAILABLE
+TEST_F(CommonShimsSlimTest, GetDataPtr_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runGetDataPtrTest(slim_c10::DeviceType::CUDA);
+}
+
+TEST_F(CommonShimsSlimTest, GetSizes_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runGetSizesTest(slim_c10::DeviceType::CUDA);
+}
+
+TEST_F(CommonShimsSlimTest, GetStrides_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runGetStridesTest(slim_c10::DeviceType::CUDA);
+}
+
+TEST_F(CommonShimsSlimTest, GetDtype_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runGetDtypeTest(slim_c10::DeviceType::CUDA);
+}
+
+TEST_F(CommonShimsSlimTest, GetDim_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runGetDimTest(slim_c10::DeviceType::CUDA);
+}
+#endif
+
+// ============================================================================
+// Error Cases
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, NullTensorArgument) {
+  void* data_ptr = nullptr;
+  int64_t* sizes = nullptr;
+  int64_t* strides = nullptr;
+  int32_t dtype = -1;
+  int64_t dim = -1;
+
+  EXPECT_EQ(
+      aoti_torch_get_data_ptr(nullptr, &data_ptr), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_sizes(nullptr, &sizes), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_strides(nullptr, &strides), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_dtype(nullptr, &dtype), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_dim(nullptr, &dim), Error::InvalidArgument);
+}
+
+TEST_F(CommonShimsSlimTest, NullReturnPointer) {
+  Tensor* tensor = createTestTensor({2, 3}, slim_c10::DeviceType::CPU);
+
+  EXPECT_EQ(aoti_torch_get_data_ptr(tensor, nullptr), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, nullptr), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, nullptr), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_dtype(tensor, nullptr), Error::InvalidArgument);
+  EXPECT_EQ(aoti_torch_get_dim(tensor, nullptr), Error::InvalidArgument);
+}
+
+// ============================================================================
+// Edge Cases
+// ============================================================================
+
+TEST_F(CommonShimsSlimTest, ScalarTensor) {
+  std::vector<int64_t> sizes = {};
+  std::vector<int64_t> strides = {};
+  slim_c10::Device device(slim_c10::DeviceType::CPU, 0);
+
+  Tensor* tensor = new Tensor(slim::empty_strided(
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      slim_c10::ScalarType::Float,
+      device));
+  trackTensor(tensor);
+
+  // Get sizes and strides for 0D tensor
+  int64_t* ret_sizes = nullptr;
+  int64_t* ret_strides = nullptr;
+  int64_t ret_dim = -1;
+
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &ret_sizes), Error::Ok);
+  EXPECT_NE(ret_sizes, nullptr);
+
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &ret_strides), Error::Ok);
+  EXPECT_NE(ret_strides, nullptr);
+
+  EXPECT_EQ(aoti_torch_get_dim(tensor, &ret_dim), Error::Ok);
+  EXPECT_EQ(ret_dim, 0);
+}
+
+TEST_F(CommonShimsSlimTest, LargeTensor) {
+  std::vector<int64_t> sizes = {100, 200, 300};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+  slim_c10::Device device(slim_c10::DeviceType::CPU, 0);
+
+  Tensor* tensor = new Tensor(slim::empty_strided(
+      slim::makeArrayRef(sizes),
+      slim::makeArrayRef(strides),
+      slim_c10::ScalarType::Float,
+      device));
+  trackTensor(tensor);
+
+  int64_t* ret_sizes = nullptr;
+  int64_t* ret_strides = nullptr;
+
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &ret_sizes), Error::Ok);
+  EXPECT_EQ(ret_sizes[0], 100);
+  EXPECT_EQ(ret_sizes[1], 200);
+  EXPECT_EQ(ret_sizes[2], 300);
+
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &ret_strides), Error::Ok);
+  EXPECT_EQ(ret_strides[0], 60000); // 200 * 300
+  EXPECT_EQ(ret_strides[1], 300);
+  EXPECT_EQ(ret_strides[2], 1);
+}
+
+TEST_F(CommonShimsSlimTest, ConsistentPointerReturn) {
+  Tensor* tensor = createTestTensor({2, 3, 4}, slim_c10::DeviceType::CPU);
+
+  // Multiple calls should return the same pointer (for SlimTensor)
+  int64_t* sizes_ptr1 = nullptr;
+  int64_t* sizes_ptr2 = nullptr;
+
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr1), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_sizes(tensor, &sizes_ptr2), Error::Ok);
+  EXPECT_EQ(sizes_ptr1, sizes_ptr2);
+
+  int64_t* strides_ptr1 = nullptr;
+  int64_t* strides_ptr2 = nullptr;
+
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr1), Error::Ok);
+  EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
+  EXPECT_EQ(strides_ptr1, strides_ptr2);
+}