Skip to content

Commit 5e5319d

Browse files
Merge pull request #692 from KevinMusgrave/dev
v2.5.0
2 parents 3a14f82 + ef65345 commit 5e5319d

File tree

6 files changed

+90
-17
lines changed

6 files changed

+90
-17
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.4.1"
1+
__version__ = "2.5.0"

src/pytorch_metric_learning/losses/manifold_loss.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
9090
if self.lambdaC != np.inf:
9191
F = F[:N, N:]
9292
loss_int = F - F[torch.arange(N), meta_classes].view(-1, 1) + self.margin
93-
loss_int[
94-
torch.arange(N), meta_classes
95-
] = -np.inf # This way avoid numerical cancellation happening # NoQA
93+
loss_int[torch.arange(N), meta_classes] = (
94+
-np.inf
95+
) # This way avoid numerical cancellation happening # NoQA
9696
# instead with subtraction of margin term # NoQA
97-
loss_int[
98-
loss_int < 0
99-
] = -np.inf # This way no loss for positive correlation with own proxy
97+
loss_int[loss_int < 0] = (
98+
-np.inf
99+
) # This way no loss for positive correlation with own proxy
100100

101101
loss_int = torch.exp(loss_int)
102102
loss_int = torch.log(1 + torch.sum(loss_int, dim=1))
@@ -106,9 +106,9 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
106106
F_e, F_p.unsqueeze(1), dim=-1
107107
).t()
108108
loss_ctx += -loss_ctx[torch.arange(N), meta_classes].view(-1, 1) + self.margin
109-
loss_ctx[
110-
torch.arange(N), meta_classes
111-
] = -np.inf # This way avoid numerical cancellation happening # NoQA
109+
loss_ctx[torch.arange(N), meta_classes] = (
110+
-np.inf
111+
) # This way avoid numerical cancellation happening # NoQA
112112
# instead with subtraction of margin term # NoQA
113113
loss_ctx[loss_ctx < 0] = -np.inf
114114

src/pytorch_metric_learning/testers/base_tester.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,10 @@ def test(
306306
query_split_name,
307307
reference_split_names,
308308
)
309-
self.end_of_testing_hook(self) if self.end_of_testing_hook else c_f.LOGGER.info(
310-
self.all_accuracies
309+
(
310+
self.end_of_testing_hook(self)
311+
if self.end_of_testing_hook
312+
else c_f.LOGGER.info(self.all_accuracies)
311313
)
312314
del self.embeddings_and_labels
313315
return self.all_accuracies

src/pytorch_metric_learning/utils/loss_and_miner_utils.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,57 @@ def neg_pairs_from_tuple(indices_tuple):
8585

8686

8787
def get_all_triplets_indices(labels, ref_labels=None):
88-
matches, diffs = get_matches_and_diffs(labels, ref_labels)
89-
triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
88+
all_matches, all_diffs = get_matches_and_diffs(labels, ref_labels)
89+
90+
if (
91+
all_matches.shape[0] * all_matches.shape[1] * all_matches.shape[1]
92+
< torch.iinfo(torch.int32).max
93+
):
94+
# torch.nonzero is not supported for tensors with more than INT_MAX elements
95+
return get_all_triplets_indices_vectorized_method(all_matches, all_diffs)
96+
97+
return get_all_triplets_indices_loop_method(labels, all_matches, all_diffs)
98+
99+
100+
def get_all_triplets_indices_vectorized_method(all_matches, all_diffs):
101+
triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)
90102
return torch.where(triplets)
91103

92104

105+
def get_all_triplets_indices_loop_method(labels, all_matches, all_diffs):
106+
all_matches, all_diffs = all_matches.bool(), all_diffs.bool()
107+
108+
# Find anchors with at least a positive and a negative
109+
indices = torch.arange(0, len(labels), device=labels.device)
110+
indices = indices[all_matches.any(dim=1) & all_diffs.any(dim=1)]
111+
112+
# No triplets found
113+
if len(indices) == 0:
114+
return (
115+
torch.tensor([], device=labels.device, dtype=labels.dtype),
116+
torch.tensor([], device=labels.device, dtype=labels.dtype),
117+
torch.tensor([], device=labels.device, dtype=labels.dtype),
118+
)
119+
120+
# Compute all triplets
121+
anchors = []
122+
positives = []
123+
negatives = []
124+
for i in indices:
125+
matches = all_matches[i].nonzero(as_tuple=False).squeeze(1)
126+
diffs = all_diffs[i].nonzero(as_tuple=False).squeeze(1)
127+
nd = len(diffs)
128+
nm = len(matches)
129+
matches = matches.repeat_interleave(nd)
130+
diffs = diffs.repeat(nm)
131+
anchors.append(
132+
torch.full((len(matches),), i, dtype=labels.dtype, device=labels.device)
133+
)
134+
positives.append(matches)
135+
negatives.append(diffs)
136+
return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)
137+
138+
93139
# sample triplets, with a weighted distribution if weights is specified.
94140
def get_random_triplet_indices(
95141
labels, ref_labels=None, t_per_anchor=None, weights=None

tests/utils/test_calculate_accuracies.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ def test_accuracy_calculator(self):
6767
"query_labels": query_labels,
6868
"label_counts": label_counts,
6969
"knn_labels": knn_labels,
70-
"not_lone_query_mask": torch.ones(6, dtype=torch.bool)
71-
if i == 0
72-
else torch.zeros(6, dtype=torch.bool),
70+
"not_lone_query_mask": (
71+
torch.ones(6, dtype=torch.bool)
72+
if i == 0
73+
else torch.zeros(6, dtype=torch.bool)
74+
),
7375
}
7476

7577
function_dict = AC.get_function_dict()

tests/utils/test_loss_and_miner_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,29 @@ def test_remove_self_comparisons_small_ref(self):
291291
self.assertTrue(torch.equal(a1, correct_a1))
292292
self.assertTrue(torch.equal(p, correct_p))
293293

294+
def test_get_all_triplets_indices(self):
295+
torch.manual_seed(920)
296+
for dtype in TEST_DTYPES:
297+
for batch_size in [32, 256, 512]:
298+
for ref_labels in [None, torch.randint(0, 5, size=(batch_size // 2,))]:
299+
labels = torch.randint(0, 5, size=(batch_size,))
300+
301+
a, p, n = lmu.get_all_triplets_indices(labels, ref_labels)
302+
matches, diffs = lmu.get_matches_and_diffs(labels, ref_labels)
303+
304+
a2, p2, n2 = lmu.get_all_triplets_indices_vectorized_method(
305+
matches, diffs
306+
)
307+
a3, p3, n3 = lmu.get_all_triplets_indices_loop_method(
308+
labels, matches, diffs
309+
)
310+
self.assertTrue(
311+
(a == a2).all() and (p == p2).all() and (n == n2).all()
312+
)
313+
self.assertTrue(
314+
(a == a3).all() and (p == p3).all() and (n == n3).all()
315+
)
316+
294317

295318
if __name__ == "__main__":
296319
unittest.main()

0 commit comments

Comments
 (0)