We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f5a06c8 commit b5383c9Copy full SHA for b5383c9
src/pytorch_metric_learning/utils/loss_and_miner_utils.py
@@ -86,6 +86,13 @@ def neg_pairs_from_tuple(indices_tuple):
86
87
def get_all_triplets_indices(labels, ref_labels=None):
88
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
89
+
90
+ if (all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
91
+ < torch.iinfo(torch.int32).max):
92
+ # torch.nonzero is not supported for tensors with more than INT_MAX elements
93
+ triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
94
+ return torch.where(triplets)
95
96
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
97
98
# Find anchors with at least a positive and a negative
0 commit comments