-
Notifications
You must be signed in to change notification settings - Fork 404
[TransferEngine]feat: add tensor transfer Read/Write API for transfer-engine #703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Hi @Risc-lt |
@staryxchen Thx for response. We are currently working on experiments of object-level abstract for subtle migration. This is a draft to be improved. |
Hi, if you're interested. Feel free to join us. @staryxchen |
Okay. I will keep a close eye on this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds tensor-specific transfer functionality to the transfer engine, providing high-level APIs for PyTorch tensor serialization and deserialization over the network. The implementation wraps lower-level transfer operations with automatic metadata handling and tensor reconstruction.
Key changes include:
- New tensor transfer APIs (
transfer_tensor_sync_write
andtransfer_tensor_sync_read
) for PyTorch tensors - Automatic tensor metadata serialization including dtype, dimensions, and shape information
- Support for multiple tensor data types (float32, int32, bool, etc.) with up to 4 dimensions
- Comprehensive test suite covering various tensor types and scenarios
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
File | Description |
---|---|
test_transfer_tensor.py | Comprehensive test suite for tensor transfer functionality |
transfer_engine_py.h | Header definitions for tensor transfer APIs and metadata structures |
transfer_engine_py.cpp | Implementation of tensor serialization/deserialization logic |
|
||
#include <pybind11/stl.h> | ||
|
||
auto torch = py::module_::import("torch"); |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The global torch module import at file scope could cause issues if torch is not available when the module is loaded. Consider importing torch lazily within functions that need it, with proper error handling for cases where torch is not installed.
auto torch = py::module_::import("torch"); | |
// Lazy import torch within functions that need it, with error handling. | |
static py::object import_torch() { | |
try { | |
return py::module_::import("torch"); | |
} catch (const py::error_already_set &e) { | |
throw std::runtime_error("Failed to import torch Python module. Is torch installed? Error: " + std::string(e.what())); | |
} | |
} |
Copilot uses AI. Check for mistakes.
return TransferEnginePy{}.create_typed_array<uint64_t>(data, offset, total_length); | ||
}, // UINT64 = 9 | ||
[](char* data, size_t offset, size_t total_length) { | ||
return TransferEnginePy{}.create_typed_array<bool>(data, offset, total_length); |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each lambda creates a temporary TransferEnginePy object to call create_typed_array. This is inefficient and unnecessary since create_typed_array could be made static or the lambdas could directly implement the array creation logic.
return TransferEnginePy{}.create_typed_array<bool>(data, offset, total_length); | |
return TransferEnginePy::create_typed_array<float>(data, offset, total_length); | |
}, // FLOAT32 = 0 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<double>(data, offset, total_length); | |
}, // FLOAT64 = 1 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<int8_t>(data, offset, total_length); | |
}, // INT8 = 2 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<uint8_t>(data, offset, total_length); | |
}, // UINT8 = 3 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<int16_t>(data, offset, total_length); | |
}, // INT16 = 4 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<uint16_t>(data, offset, total_length); | |
}, // UINT16 = 5 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<int32_t>(data, offset, total_length); | |
}, // INT32 = 6 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<uint32_t>(data, offset, total_length); | |
}, // UINT32 = 7 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<int64_t>(data, offset, total_length); | |
}, // INT64 = 8 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<uint64_t>(data, offset, total_length); | |
}, // UINT64 = 9 | |
[](char* data, size_t offset, size_t total_length) { | |
return TransferEnginePy::create_typed_array<bool>(data, offset, total_length); |
Copilot uses AI. Check for mistakes.
int32_t dtype; | ||
int32_t ndim; | ||
int32_t shape[4]; | ||
} __attribute__((packed)); |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fixed-size shape array limits tensors to 4 dimensions. Consider using a more flexible approach or document this limitation clearly, as PyTorch tensors can have more than 4 dimensions in practice.
} __attribute__((packed)); | |
std::vector<int32_t> shape; | |
}; |
Copilot uses AI. Check for mistakes.
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32) | ||
|
||
# Calculate total size (metadata + tensor data) | ||
metadata_size = 24 # sizeof(TensorMetadata) = 4 * 4 bytes |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hardcoded metadata size of 24 bytes is fragile and could break if the TensorMetadata structure changes. Consider calculating this dynamically or defining it as a constant that can be imported from the C++ side.
metadata_size = 24 # sizeof(TensorMetadata) = 4 * 4 bytes | |
metadata_size = ctypes.sizeof(TensorMetadata) # Dynamically calculated |
Copilot uses AI. Check for mistakes.
if (i < ndim) { | ||
metadata.shape[i] = shape_tuple[i].cast<int32_t>(); | ||
} else { | ||
metadata.shape[i] = -1; |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Using -1 as a sentinel value for unused shape dimensions is unclear. Consider using 0 or defining a named constant to make the intent more explicit.
metadata.shape[i] = -1; | |
metadata.shape[i] = UNUSED_DIMENSION; |
Copilot uses AI. Check for mistakes.
for (int i = 0; i < metadata.ndim; i++) { | ||
if (metadata.shape[i] > 0) { // Only add valid dimensions | ||
shape_vec.push_back(metadata.shape[i]); | ||
} |
Copilot
AI
Aug 8, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition metadata.shape[i] > 0
excludes dimensions with size 0, but zero-sized dimensions are valid in PyTorch tensors. This could cause incorrect tensor reconstruction for tensors with empty dimensions.
} | |
shape_vec.push_back(metadata.shape[i]); |
Copilot uses AI. Check for mistakes.
Add tensor-specific transfering handler to wrap up lower level executions of serialization and deserialization for pytorch tensor, memory registration and subtlr migration.
Plz review this. Thx. @stmatengss