Skip to content

Commit b5383c9

Browse files
committed
Add original approach for cases with < INT_MAX elements in get_all_triplets_indices
1 parent f5a06c8 commit b5383c9

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/pytorch_metric_learning/utils/loss_and_miner_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ def neg_pairs_from_tuple(indices_tuple):
8686

8787
def get_all_triplets_indices(labels, ref_labels=None):
8888
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
89+
90+
if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
91+
< torch.iinfo(torch.int32).max):
92+
# torch.nonzero is not supported for tensors with more than INT_MAX elements
93+
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
94+
return torch.where(triplets)
95+
8996
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
9097

9198
# Find anchors with at least a positive and a negative

0 commit comments

Comments
 (0)