Commit 1f62a2f

Merge pull request #677 from KevinMusgrave/dev
v2.4.0
2 parents a363750 + 649e110 commit 1f62a2f

15 files changed: +844 −16 lines

.github/workflows/base_test_workflow.yml

Lines changed: 9 additions & 9 deletions

@@ -13,15 +13,13 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.8]
-        pytorch-version: [1.6, 1.11]
-        torchvision-version: [0.7.0, 0.12.0]
-        with-collect-stats: [false]
-        exclude:
-          - pytorch-version: 1.6
-            torchvision-version: 0.12.0
-          - pytorch-version: 1.11
-            torchvision-version: 0.7.0
+        include:
+          - python-version: 3.8
+            pytorch-version: 1.6
+            torchvision-version: 0.7
+          - python-version: 3.9
+            pytorch-version: 2.1
+            torchvision-version: 0.16
 
     steps:
       - uses: actions/checkout@v2
@@ -34,6 +32,8 @@ jobs:
         pip install .[with-hooks-cpu]
         pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
         pip install --upgrade protobuf==3.20.1
+        pip install six
+        pip install packaging
     - name: Run unit tests
       run: |
         TEST_DTYPES=float32,float64 TEST_DEVICE=cpu WITH_COLLECT_STATS=${{ matrix.with-collect-stats }} python -m unittest discover -t . -s tests/${{ inputs.module-to-test }}

docs/losses.md

Lines changed: 27 additions & 0 deletions

@@ -345,6 +345,19 @@ The queue can be cleared like this:
 loss_fn.reset_queue()
 ```
 
+## DynamicSoftMarginLoss
+[Learning Local Descriptors With a CDF-Based Dynamic Soft Margin](https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf)
+```python
+losses.DynamicSoftMarginLoss(min_val=-2.0, num_bins=10, momentum=0.01, **kwargs)
+```
+
+**Parameters**:
+
+* **min_val**: minimum significant value for `d_pos - d_neg`
+* **num_bins**: number of equally spaced bins for the partition of the interval `[min_val, +∞)`
+* **momentum**: weight assigned to the histogram computed from the current batch
+
+
 ## FastAPLoss
 [Deep Metric Learning to Rank](http://openaccess.thecvf.com/content_CVPR_2019/papers/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.pdf){target=_blank}
 

@@ -993,6 +1006,20 @@ loss_optimizer.step()
 
 * **loss**: The loss per element in the batch, that results in a non zero exponent in the cross entropy expression. Reduction type is ```"element"```.
 
+## RankedListLoss
+[Ranked List Loss for Deep Metric Learning](https://arxiv.org/abs/1903.03238)
+```python
+losses.RankedListLoss(margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs)
+```
+
+**Parameters**:
+
+* **margin** (float): margin between positive and negative set
+* **imbalance** (float): tradeoff between positive and negative sets. As the name suggests, this takes into account
+the imbalance between positive and negative samples in the dataset
+* **alpha** (float): smallest distance between negative points
+* **Tp & Tn** (float): temperatures for, respectively, positive and negative pair weighting.
+
 
 ## SelfSupervisedLoss
 
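A minimal usage sketch for the two new losses added by this PR, with random tensors standing in for real embeddings and labels (the hyperparameter values here are illustrative, not recommendations):

```python
import torch
from pytorch_metric_learning import losses

embeddings = torch.randn(32, 128)      # batch of 32 embedding vectors
labels = torch.randint(0, 8, (32,))    # 8 classes

dsm_loss_fn = losses.DynamicSoftMarginLoss(min_val=-2.0, num_bins=10, momentum=0.01)
rll_loss_fn = losses.RankedListLoss(margin=0.2, Tn=1.0)

dsm_loss = dsm_loss_fn(embeddings, labels)  # already reduced internally
rll_loss = rll_loss_fn(embeddings, labels)  # per-element losses, reduced by the default reducer
```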

src/pytorch_metric_learning/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "2.3.0"
+__version__ = "2.4.0"

src/pytorch_metric_learning/losses/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
 from .contrastive_loss import ContrastiveLoss
 from .cosface_loss import CosFaceLoss
 from .cross_batch_memory import CrossBatchMemory
+from .dynamic_soft_margin_loss import DynamicSoftMarginLoss
 from .fast_ap_loss import FastAPLoss
 from .generic_pair_loss import GenericPairLoss
 from .histogram_loss import HistogramLoss
@@ -26,6 +27,7 @@
 from .pnp_loss import PNPLoss
 from .proxy_anchor_loss import ProxyAnchorLoss
 from .proxy_losses import ProxyNCALoss
+from .ranked_list_loss import RankedListLoss
 from .self_supervised_loss import SelfSupervisedLoss
 from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
 from .soft_triple_loss import SoftTripleLoss
src/pytorch_metric_learning/losses/dynamic_soft_margin_loss.py (new file)

Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
+import numpy as np
+import torch
+
+from ..distances import LpDistance
+from ..utils import common_functions as c_f
+from ..utils import loss_and_miner_utils as lmu
+from .base_metric_loss_function import BaseMetricLossFunction
+
+
+def find_hard_negatives(dmat):
+    """
+    a = A * P'
+    A: N * ndim
+    P: N * ndim
+
+    a1p1 a1p2 a1p3 a1p4 ...
+    a2p1 a2p2 a2p3 a2p4 ...
+    a3p1 a3p2 a3p3 a3p4 ...
+    a4p1 a4p2 a4p3 a4p4 ...
+     ...  ...  ...  ...
+    """
+
+    pos = dmat.diag()
+    dmat.fill_diagonal_(np.inf)
+
+    min_a, _ = torch.min(dmat, dim=0)
+    min_p, _ = torch.min(dmat, dim=1)
+    neg = torch.min(min_a, min_p)
+    return pos, neg
+
+
+class DynamicSoftMarginLoss(BaseMetricLossFunction):
+    r"""Loss function with dynamic margin parameter introduced in https://openaccess.thecvf.com/content_ICCV_2019/papers/Zhang_Learning_Local_Descriptors_With_a_CDF-Based_Dynamic_Soft_Margin_ICCV_2019_paper.pdf
+
+    Args:
+        min_val: minimum significant value for `d_pos - d_neg`
+        num_bins: number of equally spaced bins for the partition of the interval [min_val, :math:`+\infty`)
+        momentum: weight assigned to the histogram computed from the current batch
+    """
+
+    def __init__(self, min_val=-2.0, num_bins=10, momentum=0.01, **kwargs):
+        super().__init__(**kwargs)
+        c_f.assert_distance_type(self, LpDistance, normalize_embeddings=True, p=2)
+        self.min_val = min_val
+        self.num_bins = int(num_bins)
+        self.delta = 2 * abs(min_val) / num_bins
+        self.momentum = momentum
+        self.hist_ = torch.zeros((num_bins,))
+        self.add_to_recordable_attributes(list_of_names=["num_bins"], is_stat=False)
+
+    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
+        self.hist_ = c_f.to_device(
+            self.hist_, tensor=embeddings, dtype=embeddings.dtype
+        )
+
+        if labels is None:
+            loss = self.compute_loss_without_labels(
+                embeddings, labels, indices_tuple, ref_emb, ref_labels
+            )
+        else:
+            loss = self.compute_loss_with_labels(
+                embeddings, labels, indices_tuple, ref_emb, ref_labels
+            )
+
+        if len(loss) == 0:
+            return self.zero_losses()
+
+        self.update_histogram(loss)
+        loss = self.weigh_loss(loss)
+        loss = loss.mean()
+        return {
+            "loss": {
+                "losses": loss,
+                "indices": None,
+                "reduction_type": "already_reduced",
+            }
+        }
+
+    def compute_loss_without_labels(
+        self, embeddings, labels, indices_tuple, ref_emb, ref_labels
+    ):
+        mat = self.distance(embeddings, ref_emb)
+        r, c = mat.size()
+
+        d_pos = torch.zeros(max(r, c))
+        d_pos = c_f.to_device(d_pos, tensor=embeddings, dtype=embeddings.dtype)
+        d_pos[: min(r, c)] = mat.diag()
+        mat.fill_diagonal_(np.inf)
+
+        min_a, min_p = torch.zeros(max(r, c)), torch.zeros(
+            max(r, c)
+        )  # Check for unequal number of anchors and positives
+        min_a = c_f.to_device(min_a, tensor=embeddings, dtype=embeddings.dtype)
+        min_p = c_f.to_device(min_p, tensor=embeddings, dtype=embeddings.dtype)
+        min_a[:c], _ = torch.min(mat, dim=0)
+        min_p[:r], _ = torch.min(mat, dim=1)
+
+        d_neg = torch.min(min_a, min_p)
+        return d_pos - d_neg
+
+    def compute_loss_with_labels(
+        self, embeddings, labels, indices_tuple, ref_emb, ref_labels
+    ):
+        anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(
+            indices_tuple, labels, ref_labels, t_per_anchor="all"
+        )  # Use all instead of t_per_anchor=1 to be deterministic
+        mat = self.distance(embeddings, ref_emb)
+        d_pos, d_neg = mat[anchor_idx, positive_idx], mat[anchor_idx, negative_idx]
+        return d_pos - d_neg
+
+    def update_histogram(self, data):
+        idx, alpha = torch.floor((data - self.min_val) / self.delta).to(
+            dtype=torch.long
+        ), torch.frac((data - self.min_val) / self.delta)
+        momentum = self.momentum if self.hist_.sum() != 0 else 1.0
+        self.hist_ = torch.scatter_add(
+            (1.0 - momentum) * self.hist_, 0, idx, momentum * (1 - alpha)
+        )
+        self.hist_ = torch.scatter_add(self.hist_, 0, idx + 1, momentum * alpha)
+        self.hist_ /= self.hist_.sum()
+
+    def weigh_loss(self, data):
+        CDF = torch.cumsum(self.hist_, 0)
+        idx = torch.floor((data - self.min_val) / self.delta).to(dtype=torch.long)
+        return CDF[idx] * data
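`update_histogram` and `weigh_loss` implement soft (linearly interpolated) binning followed by CDF weighting: each `d_pos - d_neg` value splits its mass between two adjacent bins, and the cumulative distribution of the running histogram then scales each sample's loss. A standalone sketch of that binning step, with made-up values (none of these names are part of the library):

```python
import torch

min_val, num_bins = -2.0, 10
delta = 2 * abs(min_val) / num_bins        # bin width: 0.4

data = torch.tensor([-0.5, 0.1, 0.9])      # example d_pos - d_neg values
scaled = (data - min_val) / delta
idx = scaled.floor().long()                # left bin of each sample
alpha = scaled.frac()                      # fractional position inside the bin

hist = torch.zeros(num_bins)
hist = torch.scatter_add(hist, 0, idx, 1 - alpha)  # mass to the left bin
hist = torch.scatter_add(hist, 0, idx + 1, alpha)  # remainder to the right bin
hist /= hist.sum()

cdf = torch.cumsum(hist, 0)
print(cdf[idx] * data)                     # each loss scaled by its CDF weight
```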

src/pytorch_metric_learning/losses/histogram_loss.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ def __init__(self, n_bins: int = None, delta: float = None, **kwargs):
             n_bins = 100
 
         self.delta = delta if delta is not None else 2 / n_bins
-        self.add_to_recordable_attributes(name="delta", is_stat=True)
+        self.add_to_recordable_attributes(name="delta", is_stat=False)
 
     def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         c_f.labels_or_indices_tuple_required(labels, indices_tuple)

src/pytorch_metric_learning/losses/pnp_loss.py

Lines changed: 2 additions & 2 deletions

@@ -68,8 +68,8 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         else:
             raise Exception(f"variant <{self.variant}> not available!")
 
-        loss = torch.sum(sim_all_rk * I_pos, dim=-1) / N_pos.reshape(-1)
-        loss = torch.sum(loss) / N
+        loss = torch.sum(sim_all_rk * I_pos, dim=-1)[safe_N] / N_pos[safe_N].reshape(-1)
+        loss = torch.sum(loss) / torch.sum(safe_N)
         if self.variant == "Dq":
             loss = 1 - loss
 
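The rewritten reduction averages only over anchors that actually have positives. A minimal sketch of the masking pattern, assuming `safe_N` (defined earlier in the file, outside this hunk) marks anchors with at least one positive; the values here are illustrative:

```python
import torch

per_anchor = torch.tensor([0.7, 0.0, 0.3])  # summed similarity per anchor
N_pos = torch.tensor([2, 0, 1])             # positives per anchor
safe_N = N_pos > 0                          # anchors with at least one positive

# Dividing only the safe rows avoids the 0/0 the old code produced
# for anchors without positives.
loss = (per_anchor[safe_N] / N_pos[safe_N]).sum() / safe_N.sum()
print(loss)  # tensor(0.3250)
```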

src/pytorch_metric_learning/losses/ranked_list_loss.py (new file)

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+import warnings
+
+import torch
+
+from ..distances import LpDistance
+from ..utils import common_functions as c_f
+from .base_metric_loss_function import BaseMetricLossFunction
+
+
+class RankedListLoss(BaseMetricLossFunction):
+    r"""Ranked List Loss described in https://arxiv.org/abs/1903.03238
+    Default parameters correspond to RLL-Simpler, preferred for exploratory analysis.
+
+    Args:
+        * margin (float): margin between positive and negative set
+        * imbalance (float): tradeoff between positive and negative sets. As the name suggests, this takes into account
+          the imbalance between positive and negative samples in the dataset
+        * alpha (float): smallest distance between negative points
+        * Tp & Tn (float): temperatures for, respectively, positive and negative pair weighting.
+    """
+
+    def __init__(self, margin, Tn, imbalance=0.5, alpha=None, Tp=0, **kwargs):
+        super().__init__(**kwargs)
+
+        self.margin = margin
+
+        assert 0 <= imbalance <= 1, "Imbalance must be between 0 and 1"
+        self.imbalance = imbalance
+
+        if alpha is not None:
+            self.alpha = alpha
+        else:
+            self.alpha = 1 + margin / 2
+
+        if Tp > 5 or Tn > 5:
+            warnings.warn(
+                "Values of Tp or Tn are too high. Too large temperature values may lead to overflow."
+            )
+
+        self.Tp = Tp
+        self.Tn = Tn
+        self.add_to_recordable_attributes(
+            list_of_names=["imbalance", "alpha", "margin", "Tp", "Tn"], is_stat=False
+        )
+
+    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
+        c_f.labels_required(labels)
+        c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
+        c_f.indices_tuple_not_supported(indices_tuple)
+
+        mat = self.distance(embeddings, embeddings)
+        # mat.fill_diagonal_(0)
+        mat = mat - mat * torch.eye(len(mat), device=embeddings.device)
+        mat = c_f.to_device(mat, device=embeddings.device, dtype=embeddings.dtype)
+        y = labels.unsqueeze(1) == labels.unsqueeze(0)
+
+        P_star = torch.zeros_like(mat)
+        N_star = torch.zeros_like(mat)
+        w_p = torch.zeros_like(mat)
+        w_n = torch.zeros_like(mat)
+
+        N_star[(~y) * (mat < self.alpha)] = mat[(~y) * (mat < self.alpha)]
+        y.fill_diagonal_(False)
+        P_star[y * (mat > (self.alpha - self.margin))] = mat[
+            y * (mat > (self.alpha - self.margin))
+        ]
+
+        w_p[P_star > 0] = torch.exp(
+            self.Tp * (P_star[P_star > 0] - (self.alpha - self.margin))
+        )
+        w_n[N_star > 0] = torch.exp(self.Tn * (self.alpha - N_star[N_star > 0]))
+
+        loss_P = torch.sum(
+            w_p * (P_star - (self.alpha - self.margin)), dim=1
+        ) / torch.sum(w_p + 1e-5, dim=1)
+
+        loss_N = torch.sum(w_n * (self.alpha - N_star), dim=1) / torch.sum(
+            w_n + 1e-5, dim=1
+        )
+
+        # with torch.no_grad():
+        #     loss_P[loss_P.isnan()] = 0
+        #     loss_N[loss_N.isnan()] = 0
+
+        loss_RLL = (1 - self.imbalance) * loss_P + self.imbalance * loss_N
+
+        return {
+            "loss": {
+                "losses": loss_RLL,
+                "indices": c_f.torch_arange_from_size(loss_RLL),
+                "reduction_type": "element",
+            }
+        }
+
+    def get_default_distance(self):
+        return LpDistance()
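In `compute_loss`, a negative enters the non-trivial set `N_star` when it lies closer than `alpha`, and a positive enters `P_star` when it lies farther than `alpha - margin`. A toy sketch of those two thresholds, with made-up distances for a single anchor:

```python
import torch

alpha, margin = 1.2, 0.4                        # illustrative hyperparameters
dists = torch.tensor([0.5, 0.9, 1.1, 1.5])      # distances to the anchor
is_pos = torch.tensor([True, True, False, False])

# Positives violate the margin when farther than alpha - margin = 0.8.
in_P_star = is_pos & (dists > alpha - margin)   # [False, True, False, False]
# Negatives violate the margin when closer than alpha = 1.2.
in_N_star = ~is_pos & (dists < alpha)           # [False, False, True, False]
print(in_P_star, in_N_star)
```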
