Skip to content

Commit 80aa86a

Browse files
authored
Add a torch tensor contiguous check (#220)
Signed-off-by: Ti-Tai Wang <titaiwang@microsoft.com>
1 parent e499eb5 commit 80aa86a

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/onnx_ir/tensor_adapters_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,20 @@ def test_from_torch_dtype(
195195
actual = tensor_adapters.from_torch_dtype(torch_dtype)
196196
self.assertEqual(actual, expected_onnx_dtype)
197197

198+
def test_tofile_non_contiguous(self):
199+
base = torch.arange(0, 64, dtype=torch.int32).reshape(8, 8)
200+
sliced = base[:, ::2] # Stride in last dim -> non-contiguous
201+
self.assertFalse(sliced.is_contiguous())
202+
tensor = tensor_adapters.TorchTensor(sliced)
203+
# Ensure bytes correspond to the contiguous clone inside implementation
204+
expected_manual = sliced.contiguous().numpy().tobytes()
205+
with tempfile.NamedTemporaryFile() as tmp:
206+
tensor.tofile(tmp)
207+
tmp.seek(0)
208+
data = tmp.read()
209+
self.assertEqual(data, expected_manual)
210+
self.assertEqual(tensor.tobytes(), expected_manual)
211+
198212

199213
if __name__ == "__main__":
200214
unittest.main()

0 commit comments

Comments
 (0)