Skip to content

Commit 1e47e3a

Browse files
authored
Fix handling of empty MultiNestedTensor (#369)
1 parent b8d17b6 commit 1e47e3a

File tree

5 files changed

+14
-25
lines changed

5 files changed

+14
-25
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626

2727
### Fixed
2828

29+
- Fixed bug in empty `MultiNestedTensor` handling ([#369](https://github.com/pyg-team/pytorch-frame/pull/369))
30+
2931
- Fixed the split of `DataFrameTextBenchmark` ([#358](https://github.com/pyg-team/pytorch-frame/pull/358))
3032
- Fixed empty `MultiNestedTensor` col indexing ([#355](https://github.com/pyg-team/pytorch-frame/pull/355))
3133

test/data/test_multi_nested_tensor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def test_multi_nested_tensor_basics(device):
180180
# Test multi_nested_tensor[List[int]] indexing
181181
for index in [[4], [2, 2], [-4, 1, 7], [3, -7, 1, 0], []]:
182182
multi_nested_tensor_indexed = multi_nested_tensor[index]
183+
assert multi_nested_tensor_indexed.dtype == torch.long
183184
assert multi_nested_tensor_indexed.shape[0] == len(index)
184185
assert multi_nested_tensor_indexed.shape[1] == num_cols
185186
for i, idx in enumerate(index):
@@ -208,8 +209,10 @@ def test_multi_nested_tensor_basics(device):
208209

209210
# Test column List[int] indexing
210211
for index in [[4], [2, 2], [-4, 1, 7], [3, -7, 1, 0], []]:
212+
multi_nested_tensor_indexed = multi_nested_tensor[:, index]
211213
assert_equal(column_select(tensor_mat, index),
212-
multi_nested_tensor[:, index])
214+
multi_nested_tensor_indexed)
215+
assert multi_nested_tensor_indexed.dtype == torch.long
213216

214217
# Test column-wise Boolean masking
215218
for index in [[4], [2, 3], [0, 1, 7], []]:
@@ -245,6 +248,7 @@ def test_multi_nested_tensor_basics(device):
245248
empty_multi_nested_tensor = multi_nested_tensor[:, 5:3]
246249
assert empty_multi_nested_tensor.shape[0] == num_rows
247250
assert empty_multi_nested_tensor.shape[1] == 0
251+
assert empty_multi_nested_tensor.dtype == torch.long
248252

249253
# Test column narrow
250254
assert_equal(column_select(tensor_mat, slice(3, 3 + 2)),

torch_frame/data/multi_embedding_tensor.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,7 @@ def _col_index_select(self, index: Tensor) -> MultiEmbeddingTensor:
150150
:meth:`MultiEmbeddingTensor.index_select`.
151151
"""
152152
if index.numel() == 0:
153-
return MultiEmbeddingTensor(
154-
num_rows=self.num_rows,
155-
num_cols=0,
156-
values=torch.tensor([], device=self.device),
157-
offset=torch.tensor([0], device=self.device),
158-
)
153+
return self._empty(dim=1)
159154
offset = torch.zeros(
160155
index.size(0) + 1,
161156
dtype=torch.long,
@@ -228,8 +223,8 @@ def _empty(self, dim: int) -> MultiEmbeddingTensor:
228223
return MultiEmbeddingTensor(
229224
num_rows=0 if dim == 0 else self.num_rows,
230225
num_cols=0 if dim == 1 else self.num_cols,
231-
values=torch.tensor([], device=self.device),
232-
offset=torch.tensor([0], device=self.device)
226+
values=torch.tensor([], device=self.device, dtype=self.dtype),
227+
offset=torch.tensor([0], device=self.device, dtype=torch.long)
233228
if dim == 1 else self.offset,
234229
)
235230

torch_frame/data/multi_nested_tensor.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,7 @@ def _row_index_select(self, index: Tensor) -> MultiNestedTensor:
180180
r"""Helper function called by :obj:`index_select`."""
181181
# Calculate values
182182
if index.numel() == 0:
183-
return MultiNestedTensor(
184-
num_rows=0,
185-
num_cols=self.num_cols,
186-
values=torch.tensor([], device=self.device),
187-
offset=torch.tensor([0], device=self.device),
188-
)
183+
return self._empty(dim=0)
189184
index_right = (index + 1) * self.num_cols
190185
index_left = index * self.num_cols
191186
diff = self.offset[index_right] - self.offset[index_left]
@@ -218,12 +213,7 @@ def _row_index_select(self, index: Tensor) -> MultiNestedTensor:
218213
def _col_index_select(self, index: Tensor) -> MultiNestedTensor:
219214
r"""Helper function called by :obj:`index_select`."""
220215
if index.numel() == 0:
221-
return MultiNestedTensor(
222-
num_rows=self.num_rows,
223-
num_cols=0,
224-
values=torch.tensor([], device=self.device),
225-
offset=torch.tensor([0], device=self.device),
226-
)
216+
return self._empty(dim=1)
227217
start_idx = (index + torch.arange(
228218
0,
229219
self.num_rows * self.num_cols,
@@ -320,13 +310,13 @@ def to_dense(self, fill_value: int | float) -> Tensor:
320310
return dense
321311

322312
def _empty(self, dim: int) -> MultiNestedTensor:
323-
r"""Creates an empty :class:`MultiEmbeddingTensor`.
313+
r"""Creates an empty :class:`MultiNestedTensor`.
324314
325315
Args:
326316
dim (int): The dimension to empty.
327317
328318
Returns:
329-
MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`.
319+
MultiNestedTensor: An empty :class:`MultiNestedTensor`.
330320
"""
331321
values = torch.tensor([], device=self.device, dtype=self.dtype)
332322
offset = torch.zeros(1, device=self.device, dtype=torch.long)

torch_frame/datasets/amphibians.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def __init__(self, root: str):
5454
lambda row: [col for col in target_cols if row[col] == '1'],
5555
axis=1)
5656
df = df.drop(target_cols, axis=1)
57-
import pdb
58-
pdb.set_trace()
5957

6058
# Infer the pandas dataframe automatically
6159
path = osp.join(root, 'amphibians_posprocess.csv')

0 commit comments

Comments
 (0)