Skip to content

Commit f5a06c8

Browse files
committed
Fixes the "nonzero is not supported for tensors with more than INT_MAX elements" in get_all_triplets_indices
1 parent 3a14f82 commit f5a06c8

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

src/pytorch_metric_learning/utils/loss_and_miner_utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,34 @@ def neg_pairs_from_tuple(indices_tuple):
8585

8686

8787
def get_all_triplets_indices(labels, ref_labels=None):
88-
matches, diffs = get_matches_and_diffs(labels, ref_labels)
89-
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
90-
return torch.where(triplets)
88+
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
89+
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
90+
91+
# Find anchors with at least a positive and a negative
92+
indices = torch.arange(0, len(labels), device=labels.device)
93+
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]
94+
95+
# No triplets found
96+
if len(indices) == 0:
97+
return (torch.tensor([], device=labels.device, dtype=labels.dtype),
98+
torch.tensor([], device=labels.device, dtype=labels.dtype),
99+
torch.tensor([], device=labels.device, dtype=labels.dtype))
100+
101+
# Compute all triplets
102+
anchors = []
103+
positives = []
104+
negatives = []
105+
for i in indices:
106+
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
107+
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
108+
nd = len(diffs)
109+
nm = len(matches)
110+
matches = matches.repeat_interleave(nd)
111+
diffs = diffs.repeat(nm)
112+
anchors.append(torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device))
113+
positives.append(matches)
114+
negatives.append(diffs)
115+
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)
91116

92117

93118
# sample triplets, with a weighted distribution if weights is specified.

0 commit comments

Comments
 (0)