diff --git a/CONTENTS.md b/CONTENTS.md
index 6c4bfef3..b0ab5758 100644
--- a/CONTENTS.md
+++ b/CONTENTS.md
@@ -21,12 +21,14 @@
| [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf)
| [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf)
| [**LiftedStructureLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#liftedstructureloss) | [Deep Metric Learning via Lifted Structured Feature Embedding](https://arxiv.org/pdf/1511.06452.pdf)
+| [**ManifoldLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) | [Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf)
| [**MarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#marginloss) | [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf)
| [**MultiSimilarityLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#multisimilarityloss) | [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
| [**NCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ncaloss) | [Neighbourhood Components Analysis](https://www.cs.toronto.edu/~hinton/absps/nca.pdf)
| [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf)
- [Classification is a Strong Baseline for Deep Metric Learning](https://arxiv.org/pdf/1811.12649.pdf)
| [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
| [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf)
- [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf)
- [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)
+| [**P2SGradLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) | [P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479)
| [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf)
| [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf)
| [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf)
diff --git a/README.md b/README.md
index 26d568ae..f8eb0c7b 100644
--- a/README.md
+++ b/README.md
@@ -18,13 +18,15 @@
## News
+**June 18**: v2.2.0
+- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss).
+- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss).
+- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0).
+- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).
+
**April 5**: v2.1.0
- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss)
-- Thanks to contributor [interestingzhuo](https://github.com/interestingzhuo).
-
-**January 29**: v2.0.0
-- Added SelfSupervisedLoss, plus various API improvements. See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.0.0).
-- Thanks to contributor [cwkeam](https://github.com/cwkeam).
+- Thank you [interestingzhuo](https://github.com/interestingzhuo).
## Documentation
@@ -225,6 +227,7 @@ Thanks to the contributors who made pull requests!
| Contributor | Highlights |
| -- | -- |
+|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) |
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper|
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
@@ -273,6 +276,7 @@ This library contains code that has been adapted and modified from the following
- https://github.com/ronekko/deep_metric_learning
- https://github.com/tjddus9597/Proxy-Anchor-CVPR2020
- http://kaizhao.net/regularface
+- https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts
### Logo
Thanks to [Jeff Musgrave](https://www.designgenius.ca/) for designing the logo.
diff --git a/docs/losses.md b/docs/losses.md
index b0719258..2af1f905 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -545,6 +545,57 @@ losses.LiftedStructureLoss(neg_margin=1, pos_margin=0, **kwargs):
* **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
+## ManifoldLoss
+
+[Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf)
+
+```python
+losses.ManifoldLoss(
+ l: int,
+ K: int = 50,
+ lambdaC: float = 1.0,
+ alpha: float = 0.8,
+ margin: float = 5e-4,
+ **kwargs
+ )
+```
+
+**Parameters**
+
+- **l**: embedding size.
+
+- **K**: number of proxies.
+
+- **lambdaC**: regularization weight. Used in the formula `loss = intrinsic_loss + lambdaC*context_loss`.
+ If `lambdaC=0`, then it uses only the intrinsic loss. If `lambdaC=np.inf`, then it uses only the context loss.
+
+- **alpha**: parameter of the Random Walk. Must be in the range `(0,1)`. It specifies the amount of similarity between neighboring nodes.
+
+- **margin**: margin used in the calculation of the loss.
+
+
+Example usage:
+```python
+loss_fn = losses.ManifoldLoss(128)
+
+# use random cluster centers
+loss = loss_fn(embeddings)
+# or specify indices of embeddings to use as cluster centers
+loss = loss_fn(embeddings, indices_tuple=indices)
+```
+
+**Important notes**
+
+`labels`, `ref_emb`, and `ref_labels` are not supported for this loss function.
+
+In addition, `indices_tuple` is **not** meant for the output of miners. Instead, it should be a list of indices of embeddings to be used as cluster centers.
+
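+One way to choose the cluster centers is to sample `K` indices from the batch at random (a minimal sketch; any selection strategy can be used, and `K` is assumed to match the value given to the loss):
+
+```python
+import torch
+
+K = 50  # assumed to match the loss's K
+indices = torch.randperm(len(embeddings))[:K].tolist()
+loss = loss_fn(embeddings, indices_tuple=indices)
+```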
+
+**Default reducer**:
+
+ - This loss returns an **already reduced** loss.
+
+
## MarginLoss
[Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf){target=_blank}
```python
@@ -761,6 +812,37 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
* **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
+
+## P2SGradLoss
+[P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479)
+```python
+losses.P2SGradLoss(descriptors_dim, num_classes, **kwargs)
+```
+
+**Parameters**
+
+- **descriptors_dim**: The embedding size.
+
+- **num_classes**: The number of classes in your training dataset.
+
+
+Example usage:
+```python
+loss_fn = losses.P2SGradLoss(128, 10)
+loss = loss_fn(embeddings, labels)
+```
+
+**Important notes**
+
+`indices_tuple`, `ref_emb`, and `ref_labels` are not supported for this loss function.
+
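+Since P2SGrad computes cosine similarities between embeddings and per-class weight vectors, the loss typically holds trainable class weights. A minimal sketch, assuming those weights are given their own optimizer (the optimizer choice below is only illustrative):
+
+```python
+loss_fn = losses.P2SGradLoss(128, 10)
+# hypothetical: optimize the loss's internal class weights separately from the model
+loss_optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.01)
+
+loss = loss_fn(embeddings, labels)
+loss.backward()
+loss_optimizer.step()
+```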
+
+**Default reducer**:
+
+ - This loss returns an **already reduced** loss.
+
+
+
## PNPLoss
[Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank}
```python
@@ -849,14 +931,31 @@ loss_optimizer.step()
## SelfSupervisedLoss
-A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. `SelfSupervisedLoss` automates this.
+A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`.
+
+`SelfSupervisedLoss` is a wrapper that takes care of this by creating labels internally. It assumes that:
+
+- `ref_emb[i]` is an augmented version of `embeddings[i]`.
+- `ref_emb[i]` is the only augmented version of `embeddings[i]` in the batch.
```python
+losses.SelfSupervisedLoss(loss, symmetric=True, **kwargs)
+```
+
+**Parameters**:
+
+* **loss**: The loss function to be wrapped.
+* **symmetric**: If `True`, then the embeddings in both `embeddings` and `ref_emb` are used as anchors. If `False`, then only the embeddings in `embeddings` are used as anchors.
+
+Example usage:
+
+```python
loss_fn = losses.TripletMarginLoss()
loss_fn = SelfSupervisedLoss(loss_fn)
loss = loss_fn(embeddings, ref_emb)
```
+
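+To use only the embeddings in `embeddings` as anchors, pass `symmetric=False` (a minimal sketch):
+
+```python
+loss_fn = losses.SelfSupervisedLoss(losses.TripletMarginLoss(), symmetric=False)
+loss = loss_fn(embeddings, ref_emb)
+```
+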
??? "Supported Loss Functions"
- [AngularLoss](losses.md#angularloss)
- [CircleLoss](losses.md#circleloss)
diff --git a/src/pytorch_metric_learning/utils/distributed.py b/src/pytorch_metric_learning/utils/distributed.py
index 93946eed..ce977224 100644
--- a/src/pytorch_metric_learning/utils/distributed.py
+++ b/src/pytorch_metric_learning/utils/distributed.py
@@ -34,7 +34,7 @@ def all_gather_embeddings_and_labels(emb, labels):
return ref_emb, ref_labels
-def gather(emb, labels):
+def gather_bak(emb, labels):
device = emb.device
if labels is not None:
labels = c_f.to_device(labels, device=device)
@@ -45,6 +45,28 @@ def gather(emb, labels):
)
return all_emb, all_labels, labels
+def gather(emb, labels):
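+    """Gather embeddings and labels from all replicas.
+
+    The locally computed tensors are kept in this replica's slot of the gathered result
+    so that gradients can flow through them (all_gather itself does not propagate gradients).
+    """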
+ device = emb.device
+ if labels is not None:
+ labels = c_f.to_device(labels, device=device)
+ # Gather the embeddings from every replica.
+ emb = c_f.to_device(emb, device=device)
+ emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(emb_list, emb)
+    # all_gather returns tensors without gradients, so overwrite this replica's slot
+    # with the locally computed embeddings, which do have gradients.
+ emb_list[torch.distributed.get_rank()] = emb
+ all_emb = torch.cat(emb_list, dim=0)
+
+ # Gather the labels from every replica.
+ if labels is not None:
+ labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(labels_list, labels)
+        # Keep the locally produced labels in this replica's slot, mirroring the embeddings above.
+ labels_list[torch.distributed.get_rank()] = labels
+ all_labels = torch.cat(labels_list, dim=0)
+ else:
+ all_labels = None
+ return all_emb, all_labels, labels
def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):
all_emb, all_labels, labels = gather(emb, labels)
@@ -58,7 +80,17 @@ def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):
def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=None):
device = labels.device
- curr_batch_idx = torch.arange(len(labels), device=device)
+    # curr_batch_idx must index this replica's local batch within the gathered (global)
+    # ref batch, so that self-comparisons are removed at the correct offsets.
+    local_bs = len(ref_labels) // torch.distributed.get_world_size()
+    local_b_start_idx = torch.distributed.get_rank() * local_bs
+    curr_batch_idx = torch.arange(local_b_start_idx, local_b_start_idx + local_bs, device=device)
if miner:
indices_tuple = miner(embeddings, labels, ref_emb, ref_labels)
else:
@@ -66,12 +98,26 @@ def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=N
return lmu.remove_self_comparisons(indices_tuple, curr_batch_idx, len(ref_labels))
-def gather_enqueue_mask(enqueue_mask, device):
+def gather_enqueue_mask_bak(enqueue_mask, device):
if enqueue_mask is None:
return enqueue_mask
enqueue_mask = c_f.to_device(enqueue_mask, device=device)
return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)
+def gather_enqueue_mask(enqueue_mask, device):
+ if enqueue_mask is None:
+ return enqueue_mask
+ enqueue_mask = c_f.to_device(enqueue_mask, device=device)
+ # Gather the enqueue_mask from every replica.
+ enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(enqueue_mask_list, enqueue_mask)
+
+    # Keep the locally produced mask in this replica's slot, mirroring gather() above.
+ enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask
+
+ return torch.cat(enqueue_mask_list, dim=0)
+
+
def select_ref_or_regular(regular, ref):
return regular if ref is None else ref
@@ -123,12 +169,14 @@ def forward_regular_loss(
if indices_tuple is None:
indices_tuple = get_indices_tuple(labels, all_labels)
loss = self.loss(emb, labels, indices_tuple, all_emb, all_labels)
+
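+            # Each replica's loss here covers only its local anchors against the gathered
+            # batch, so it is returned without world_size scaling.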
+ return loss
else:
loss = self.loss(
all_emb, all_labels, indices_tuple, all_ref_emb, all_ref_labels
)
- return loss * world_size
+ return loss * world_size
def forward_cross_batch(
self,
@@ -152,9 +200,71 @@ def forward_cross_batch(
emb, labels, ref_emb, ref_labels
)
enqueue_mask = gather_enqueue_mask(enqueue_mask, emb.device)
- loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_mask)
- return loss * world_size
+        loss = self.forward_cross_batch_dist_helper(
+            self.loss, all_emb, all_labels, indices_tuple, enqueue_mask
+        )
+        # The helper computes the loss only for this replica's slice of the gathered batch,
+        # so no world_size scaling is applied here.
+        return loss
+
+ def forward_cross_batch_dist_helper(self, loss_inst, embeddings, labels, indices_tuple=None, enqueue_mask=None):
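+        """Compute the CrossBatchMemory loss in a DDP setting.
+
+        The full gathered batch is enqueued into the memory bank, but the loss itself is
+        computed only for this replica's slice of the batch, mirroring CrossBatchMemory.forward.
+        """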
+ if indices_tuple is not None and enqueue_mask is not None:
+ raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
+ if enqueue_mask is not None:
+ assert len(enqueue_mask) == len(embeddings)
+ else:
+ assert len(embeddings) <= len(loss_inst.embedding_memory)
+ loss_inst.reset_stats()
+ device = embeddings.device
+ labels = c_f.to_device(labels, device=device)
+ loss_inst.embedding_memory = c_f.to_device(
+ loss_inst.embedding_memory, device=device, dtype=embeddings.dtype
+ )
+ loss_inst.label_memory = c_f.to_device(
+ loss_inst.label_memory, device=device, dtype=labels.dtype
+ )
+
+ if enqueue_mask is not None:
+ emb_for_queue = embeddings[enqueue_mask]
+ labels_for_queue = labels[enqueue_mask]
+ embeddings = embeddings[~enqueue_mask]
+ labels = labels[~enqueue_mask]
+ do_remove_self_comparisons = False
+ else:
+ emb_for_queue = embeddings
+ labels_for_queue = labels
+ do_remove_self_comparisons = True
+
+        # ==== DDP specific ====
+        # Compute the loss only for this replica's slice of the gathered batch;
+        # the full gathered batch (emb_for_queue) is still added to the memory below.
+        local_bs = len(embeddings) // torch.distributed.get_world_size()
+        local_b_start_idx = torch.distributed.get_rank() * local_bs
+        embeddings = embeddings[local_b_start_idx : local_b_start_idx + local_bs, :]
+        labels = labels[local_b_start_idx : local_b_start_idx + local_bs]
+        # ==== end DDP specific ====
+
+ queue_batch_size = len(emb_for_queue)
+ loss_inst.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)
+
+ if not loss_inst.has_been_filled:
+ E_mem = loss_inst.embedding_memory[: loss_inst.queue_idx]
+ L_mem = loss_inst.label_memory[: loss_inst.queue_idx]
+ else:
+ E_mem = loss_inst.embedding_memory
+ L_mem = loss_inst.label_memory
+
+ indices_tuple = loss_inst.create_indices_tuple(
+ embeddings,
+ labels,
+ E_mem,
+ L_mem,
+ indices_tuple,
+ do_remove_self_comparisons,
+ )
+ loss = loss_inst.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
+ return loss
class DistributedMinerWrapper(torch.nn.Module):
def __init__(self, miner, efficient=False):
@@ -172,9 +282,16 @@ def forward(self, emb, labels, ref_emb=None, ref_labels=None):
all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref(
emb, labels, ref_emb, ref_labels
)
+
if self.efficient:
all_labels = select_ref_or_regular(all_labels, all_ref_labels)
all_emb = select_ref_or_regular(all_emb, all_ref_emb)
return get_indices_tuple(labels, all_labels, emb, all_emb, self.miner)
else:
return self.miner(all_emb, all_labels, all_ref_emb, all_ref_labels)
diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py
index baaa71fe..afbc33c1 100644
--- a/tests/utils/test_distributed.py
+++ b/tests/utils/test_distributed.py
@@ -19,6 +19,7 @@
def parameters_are_equal(model1, model2):
output = True
for p1, p2 in zip(model1.parameters(), model2.parameters()):
+ print((p1.data - p2.data).abs().max())
output &= torch.allclose(p1.data, p2.data, rtol=1e-2)
return output
@@ -92,6 +93,7 @@ def single_process_function(
optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
original_model = original_model.to(device)
+ print("assert NOT equal")
assert not parameters_are_equal(original_model, ddp_mp_model.module)
for i in range(iterations):
@@ -122,6 +124,7 @@ def single_process_function(
optimizer.step()
dist.barrier()
+ print("assert equal")
assert parameters_are_equal(original_model, ddp_mp_model.module)
dist.barrier()
cleanup()
@@ -184,7 +187,7 @@ def loss_and_miner_tester(
pass_labels_to_loss_fn=True,
use_xbm_enqueue_mask=False,
):
- torch.manual_seed(75210)
+ # torch.manual_seed(75210)
loss_kwargs = {} if loss_kwargs is None else loss_kwargs
miner_kwargs = {} if miner_kwargs is None else miner_kwargs
if TEST_DEVICE == torch.device("cpu"):
@@ -205,6 +208,7 @@ def loss_and_miner_tester(
original_model = ToyMpModel().type(dtype)
model = ToyMpModel().type(dtype)
model.load_state_dict(original_model.state_dict())
+ print("assert identical")
self.assertTrue(parameters_are_equal(original_model, model))
original_model = original_model.to(TEST_DEVICE)
@@ -331,8 +335,8 @@ def loss_and_miner_tester(
)
def test_distributed_tuple_loss(self):
- for xbm in [False, True]:
- for use_ref in [False, True]:
+ for xbm in [True]:
+ for use_ref in [False]:
for use_xbm_enqueue_mask in [False, True]:
if xbm and use_ref:
continue
diff --git a/tests/utils/test_distributed_xbm_queue.py b/tests/utils/test_distributed_xbm_queue.py
new file mode 100644
index 00000000..d732abfd
--- /dev/null
+++ b/tests/utils/test_distributed_xbm_queue.py
@@ -0,0 +1,298 @@
+import unittest
+from .. import TEST_DEVICE, TEST_DTYPES
+
+
+import os
+import argparse
+import random
+import warnings
+import numpy as np
+
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import Dataset, DataLoader
+
+from pytorch_metric_learning import losses, miners
+from pytorch_metric_learning.utils import distributed as pml_dist
+
+
+def _init_fn(worker_id):
+ SEED = worker_id
+ torch.manual_seed(SEED)
+ torch.cuda.manual_seed(SEED)
+ np.random.seed(SEED)
+ random.seed(SEED)
+ torch.backends.cudnn.deterministic=True
+ torch.backends.cudnn.benchmark = False
+
+def example(local_rank, rank, world_size, args):
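+    """Run one DDP worker: train a small model with CrossBatchMemory (XBM) wrapped in
+    DistributedLossWrapper, using a MoCo-style momentum key encoder and batch shuffling.
+    """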
+ print(f"current local_rank is {local_rank}, current rank is {rank}")
+
+
+ # ============================DDP specific ===================================#
+ # create default process group
+    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are expected to be set in the
+    # environment by the launch script, so no init_method/rank/world_size is passed here.
+    dist.init_process_group("nccl")
+
+ if args.seed is not None:
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ warnings.warn(
+ "You have chosen to seed training. "
+ "This will turn on the CUDNN deterministic setting, "
+ "which can slow down your training considerably! "
+ "You may see unexpected behavior when restarting "
+ "from checkpoints."
+ )
+
+ # dataset
+ dataset = RandomDataset(args.sample_size)
+
+ print(f"world size is {world_size}")
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset,
+ drop_last=True,
+        seed=args.seed if args.seed is not None else 0,
+        # num_replicas and rank are taken from the default process group.
+ )
+ # ============================DDP specific ===================================#
+ args.batch_size //= world_size
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers_per_process,
+        pin_memory=True,
+        sampler=train_sampler,
+        worker_init_fn=_init_fn,
+    )
+
+ # create local model
+ model_q = nn.Linear(10, 10).to(local_rank)
+ with torch.no_grad():
+ model_k = nn.Linear(10, 10).to(local_rank)
+
+ #============================DDP specific end ===================================#
+ # construct DDP model
+ model_q = DDP(model_q, device_ids=[local_rank])
+
+ # define loss function and optimizer
+ # miner = miners.MultiSimilarityMiner()
+ # set margins to ensure that no pairs are left out for this example
+ miner = miners.PairMarginMiner(pos_margin=0, neg_margin=100)
+    # The miner is passed to CrossBatchMemory below, which already sees the gathered
+    # embeddings, so wrapping it in DistributedMinerWrapper is not needed here.
+ loss_fn = losses.CrossBatchMemory(loss=losses.NTXentLoss(temperature=0.07), embedding_size=10,
+ memory_size=args.memory_size, miner=miner)
+ loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn)
+ # ============================DDP specific end ===================================#
+ optimizer = optim.SGD(model_q.parameters(), lr=0.001)
+
+
+ ## train loop
+ # Iterate over the data using the DataLoader
+ for epoch in range(args.num_of_epoch):
+ train_sampler.set_epoch(epoch)
+ for worker_id, index, batch_inputs, batch_outputs in dataloader:
+ print(f"in epoch {epoch}, worker id {worker_id}, index {index}, label {batch_outputs}")
+ # forward pass
+ embeds_Q = model_q(batch_inputs.to(local_rank))
+ labels = batch_outputs.to(local_rank)
+
+ # compute output
+ with torch.no_grad(): # no gradient to keys
+ copy_params(model_q, model_k, m=0.99)
+ # ============================DDP specific ===================================#
+ # shuffle for making use of BN
+ imgK, idx_unshuffle = batch_shuffle_ddp(batch_inputs.to(local_rank))
+ embeds_K = model_k(imgK)
+ embeds_K = batch_unshuffle_ddp(embeds_K, idx_unshuffle)
+ # ============================DDP specific end ===================================#
+
+ ## same as original MOCO: same image augmented with the same label
+ all_enc = torch.cat([embeds_Q, embeds_K], dim=0)
+ labels, enqueue_mask = create_labels(embeds_Q.size(0), labels, local_rank)
+ loss = loss_fn(all_enc, labels,
+ enqueue_mask=enqueue_mask) # miner will be used in loss_fn if initialized with miner
+
+ # compute gradient and do SGD step
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+
+ # check pair size
+ print(f'label on mem bank on rank{rank}: {loss_fn.loss.label_memory}') # here we should see the same labels on mem queue on different workers
+ print(f"bs is {batch_inputs.shape[0]}")
+ print("num of pos pairs: ", loss_fn.loss.miner.num_pos_pairs)
+ print("num of neg pairs: ", loss_fn.loss.miner.num_neg_pairs)
+ print(f"epoch {epoch}, loss is {loss}")
+ # print weights
+ for param in model_q.parameters():
+ print(param.data)
+
+ dist.destroy_process_group()
+
+
+def copy_params(encQ, encK, m=None):
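+    # MoCo-style key-encoder update: copy the query weights when m is None,
+    # otherwise apply an exponential moving average with momentum m.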
+ if m is None:
+ for param_q, param_k in zip(encQ.parameters(), encK.parameters()):
+ param_k.data.copy_(param_q.data) # initialize
+ param_k.requires_grad = False # not update by gradient
+ else:
+ for param_q, param_k in zip(encQ.parameters(), encK.parameters()):
+ param_k.data = param_k.data * m + param_q.data * (1.0 - m)
+
+class RandomDataset(Dataset):
+ def __init__(self, sample_size):
+ self.samples = torch.randn(sample_size, 10)
+ self.labels = torch.arange(sample_size)
+ print(f'data set labels: {self.labels}')
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, index):
+ worker_info = torch.utils.data.get_worker_info()
+ if not worker_info:
+ raise NotImplementedError('Not implemented for num_workers=0')
+ print(f"in worker {worker_info.id}, index {index}, label is {self.labels[index]}")
+ return worker_info.id, index, self.samples[index], self.labels[index]
+
+
+def create_labels(num_pos_pairs, labels, device):
+ # create labels that indicate what the positive pairs are
+ labels = torch.cat((labels, labels)).to(device)
+
+ # we want to enqueue the output of encK, which is the 2nd half of the batch
+ enqueue_mask = torch.zeros(len(labels)).bool()
+ enqueue_mask[num_pos_pairs:] = True
+ return labels, enqueue_mask
+
+
+@torch.no_grad()
+def batch_shuffle_ddp(x):
+ """
+ Batch shuffle, for making use of BatchNorm.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ *** x should be on local_rank device
+ """
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = concat_all_gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # random shuffle index
+    idx_shuffle = torch.randperm(batch_size_all).to(x.device)
+
+ # broadcast to all gpus
+ torch.distributed.broadcast(idx_shuffle, src=0)
+
+ # index for restoring
+ idx_unshuffle = torch.argsort(idx_shuffle)
+
+ # shuffled index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this], idx_unshuffle
+
+@torch.no_grad()
+def batch_unshuffle_ddp(x, idx_unshuffle):
+ """
+ Undo batch shuffle.
+ *** Only support DistributedDataParallel (DDP) model. ***
+ """
+ # gather from all gpus
+ batch_size_this = x.shape[0]
+ x_gather = concat_all_gather(x)
+ batch_size_all = x_gather.shape[0]
+
+ num_gpus = batch_size_all // batch_size_this
+
+ # restored index for this gpu
+ gpu_idx = torch.distributed.get_rank()
+ idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+
+ return x_gather[idx_this]
+
+
+# utils
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+ ]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+class TestDistributedXbmQueue(unittest.TestCase):
+ def main(self):
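+        # Note: this is a manual DDP check. It expects RANK, LOCAL_RANK, and WORLD_SIZE to be
+        # set by the launcher (e.g. torchrun) and is not meant for regular unittest discovery.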
+ # Create the argument parser
+        parser = argparse.ArgumentParser(description='Manual DDP test of CrossBatchMemory with DistributedLossWrapper.')
+
+ # Add arguments to the parser
+        parser.add_argument('--sample_size', type=int,
+                            default=32,
+                            help='number of samples in the random dataset')
+ parser.add_argument('--batch_size', type=int,
+ default=32,
+ help='global batch_size')
+        parser.add_argument('--seed', type=int,
+                            default=None,
+                            help='random seed')
+ parser.add_argument('--memory_size', type=int,
+ default=32,
+ help='memory_size')
+ parser.add_argument('--num_workers_per_process', type=int,
+ default=8,
+ help='num_workers')
+ parser.add_argument('--num_of_epoch', type=int,
+ default=5,
+ help='num_of_epoch')
+
+ args = parser.parse_args()
+ #########################################################
+ # Get PyTorch environment variables for distributed training.
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+        # mp.spawn is not used here: the launch script is expected to set the env vars
+        # MASTER_ADDR, MASTER_PORT, NODE_RANK, WORLD_SIZE, RANK, and LOCAL_RANK.
+ os.environ['NCCL_DEBUG'] = 'INFO'
+ print(f"args.batch_size is {args.batch_size}, world_size is {world_size}")
+ assert (args.batch_size % world_size) == 0
+ example(local_rank, rank, world_size, args)
+
+
+if __name__=="__main__":
+ unittest.main()
\ No newline at end of file