Skip to content

Commit cfafd3b

Browse files
Merge pull request #689 from mkmenta/remove-big-tensor-nonzero
Fixes the "nonzero is not supported for tensors with more than INT_MAX elements" in get_all_triplets_indices
2 parents dd40036 + b5383c9 commit cfafd3b

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

src/pytorch_metric_learning/utils/loss_and_miner_utils.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,41 @@ def neg_pairs_from_tuple(indices_tuple):
8585

8686

8787
def get_all_triplets_indices(labels, ref_labels=None):
88-
matches, diffs = get_matches_and_diffs(labels, ref_labels)
89-
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
90-
return torch.where(triplets)
88+
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
89+
90+
if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
91+
< torch.iinfo(torch.int32).max):
92+
# torch.nonzero is not supported for tensors with more than INT_MAX elements
93+
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
94+
return torch.where(triplets)
95+
96+
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
97+
98+
# Find anchors with at least a positive and a negative
99+
indices = torch.arange(0, len(labels), device=labels.device)
100+
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]
101+
102+
# No triplets found
103+
if len(indices) == 0:
104+
return (torch.tensor([], device=labels.device, dtype=labels.dtype),
105+
torch.tensor([], device=labels.device, dtype=labels.dtype),
106+
torch.tensor([], device=labels.device, dtype=labels.dtype))
107+
108+
# Compute all triplets
109+
anchors = []
110+
positives = []
111+
negatives = []
112+
for i in indices:
113+
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
114+
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
115+
nd = len(diffs)
116+
nm = len(matches)
117+
matches = matches.repeat_interleave(nd)
118+
diffs = diffs.repeat(nm)
119+
anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device))
120+
positives.append(matches)
121+
negatives.append(diffs)
122+
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)
91123

92124

93125
# sample triplets, with a weighted distribution if weights is specified.

0 commit comments

Comments
 (0)