diff --git a/CONTENTS.md b/CONTENTS.md index 6c4bfef3..b0ab5758 100644 --- a/CONTENTS.md +++ b/CONTENTS.md @@ -21,12 +21,14 @@ | [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf) | [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf) | [**LiftedStructreLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#liftedstructureloss) | [Deep Metric Learning via Lifted Structured Feature Embedding](https://arxiv.org/pdf/1511.06452.pdf) +| [**ManifoldLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) | [Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf) | [**MarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#marginloss) | [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf) | [**MultiSimilarityLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#multisimilarityloss) | [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) | [**NCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ncaloss) | [Neighbourhood Components Analysis](https://www.cs.toronto.edu/~hinton/absps/nca.pdf) | [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf)
- [Classification is a Strong Baseline for DeepMetric Learning](https://arxiv.org/pdf/1811.12649.pdf) | [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf) | [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf)
- [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf)
- [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709) +| [**P2SGradLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) | [P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479) | [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf) | [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf) | [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf) diff --git a/README.md b/README.md index 26d568ae..f8eb0c7b 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,15 @@ ## News +**June 18**: v2.2.0 +- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss). +- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss). +- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0). +- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0). + **April 5**: v2.1.0 - Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) -- Thanks to contributor [interestingzhuo](https://github.com/interestingzhuo). - -**January 29**: v2.0.0 -- Added SelfSupervisedLoss, plus various API improvements. See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.0.0). -- Thanks to contributor [cwkeam](https://github.com/cwkeam). +- Thank you [interestingzhuo](https://github.com/interestingzhuo). ## Documentation @@ -225,6 +227,7 @@ Thanks to the contributors who made pull requests! | Contributor | Highlights | | -- | -- | +|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons | |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper| |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) | @@ -273,6 +276,7 @@ This library contains code that has been adapted and modified from the following - https://github.com/ronekko/deep_metric_learning - https://github.com/tjddus9597/Proxy-Anchor-CVPR2020 - http://kaizhao.net/regularface +- https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts ### Logo Thanks to [Jeff Musgrave](https://www.designgenius.ca/) for designing the logo. diff --git a/docs/losses.md b/docs/losses.md index b0719258..2af1f905 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -545,6 +545,57 @@ losses.LiftedStructureLoss(neg_margin=1, pos_margin=0, **kwargs): * **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```. +## ManifoldLoss + +[Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf) + +```python +losses.ManifoldLoss( + l: int, + K: int = 50, + lambdaC: float = 1.0, + alpha: float = 0.8, + margin: float = 5e-4, + **kwargs + ) +``` + +**Parameters** + +- **l**: embedding size. + +- **K**: number of proxies. + +- **lambdaC**: regularization weight. Used in the formula `loss = intrinsic_loss + lambdaC*context_loss`. + If `lambdaC=0`, then it uses only the intrinsic loss. If `lambdaC=np.inf`, then it uses only the context loss. + +- **alpha**: parameter of the Random Walk. Must be in the range `(0,1)`. It specifies the amount of similarity between neighboring nodes. + +- **margin**: margin used in the calculation of the loss. + + +Example usage: +```python +loss_fn = ManifoldLoss(128) + +# use random cluster centers +loss = loss_fn(embeddings) +# or specify indices of embeddings to use as cluster centers +loss = loss_fn(embeddings, indices_tuple=indices) +``` + +**Important notes** + +`labels`, `ref_emb`, and `ref_labels` are not supported for this loss function. + +In addition, `indices_tuple` is **not** for the output of miners. Instead, it is for a list of indices of embeddings to be used as cluster centers. + + +**Default reducer**: + + - This loss returns an **already reduced** loss. + + ## MarginLoss [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf){target=_blank} ```python @@ -761,6 +812,37 @@ losses.NTXentLoss(temperature=0.07, **kwargs) * **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```. + +## P2SGradLoss +[P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479) +```python +losses.P2SGradLoss(descriptors_dim, num_classes, **kwargs) +``` + +**Parameters** + +- **descriptors_dim**: The embedding size. + +- **num_classes**: The number of classes in your training dataset. + + +Example usage: +```python +loss_fn = P2SGradLoss(128, 10) +loss = loss_fn(embeddings, labels) +``` + +**Important notes** + +`indices_tuple`, `ref_emb`, and `ref_labels` are not supported for this loss function. + + +**Default reducer**: + + - This loss returns an **already reduced** loss. 
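+
+For illustration, here is a minimal sketch that exercises both of the new losses side by side. The batch size, the random choice of cluster centers for `ManifoldLoss`, and matching the number of supplied centers to `K` are assumptions made for the example, not documented requirements.
+
+```python
+import torch
+from pytorch_metric_learning import losses
+
+embeddings = torch.randn(256, 128)      # illustrative batch: 256 embeddings of size 128
+labels = torch.randint(0, 10, (256,))   # only used by P2SGradLoss
+
+# ManifoldLoss ignores labels. Cluster centers are either chosen internally
+# or supplied through indices_tuple.
+manifold_loss_fn = losses.ManifoldLoss(l=128, K=50)
+loss_a = manifold_loss_fn(embeddings)                      # random cluster centers
+centers = torch.randperm(len(embeddings))[:50].tolist()    # assumption: one index per proxy
+loss_b = manifold_loss_fn(embeddings, indices_tuple=centers)
+
+# P2SGradLoss is classification-like, so it takes labels.
+p2sgrad_loss_fn = losses.P2SGradLoss(descriptors_dim=128, num_classes=10)
+loss_c = p2sgrad_loss_fn(embeddings, labels)
+```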
+ + + ## PNPLoss [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank} ```python @@ -849,14 +931,31 @@ loss_optimizer.step() ## SelfSupervisedLoss -A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. `SelfSupervisedLoss` automates this. +A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. + +`SelfSupervisedLoss` is a wrapper that takes care of this by creating labels internally. It assumes that: + +- `ref_emb[i]` is an augmented version of `embeddings[i]`. +- `ref_emb[i]` is the only augmented version of `embeddings[i]` in the batch. ```python +losses.SelfSupervisedLoss(loss, symmetric=True, **kwargs) +``` + +**Parameters**: + +* **loss**: The loss function to be wrapped. +* **symmetric**: If `True`, then the embeddings in both `embeddings` and `ref_emb` are used as anchors. If `False`, then only the embeddings in `embeddings` are used as anchors. + +Example usage: + +``` loss_fn = losses.TripletMarginLoss() loss_fn = SelfSupervisedLoss(loss_fn) loss = loss_fn(embeddings, ref_emb) ``` + ??? "Supported Loss Functions" - [AngularLoss](losses.md#angularloss) - [CircleLoss](losses.md#circleloss) diff --git a/src/pytorch_metric_learning/utils/distributed.py b/src/pytorch_metric_learning/utils/distributed.py index 93946eed..ce977224 100644 --- a/src/pytorch_metric_learning/utils/distributed.py +++ b/src/pytorch_metric_learning/utils/distributed.py @@ -34,7 +34,7 @@ def all_gather_embeddings_and_labels(emb, labels): return ref_emb, ref_labels -def gather(emb, labels): +def gather_bak(emb, labels): device = emb.device if labels is not None: labels = c_f.to_device(labels, device=device) @@ -45,6 +45,28 @@ def gather(emb, labels): ) return all_emb, all_labels, labels +def gather(emb, labels): + device = emb.device + if labels is not None: + labels = c_f.to_device(labels, device=device) + # Gather the embeddings from every replica. + emb = c_f.to_device(emb, device=device) + emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(emb_list, emb) + # Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica.with the embeddings produced on this replica, which do have gradients. + emb_list[torch.distributed.get_rank()] = emb + all_emb = torch.cat(emb_list, dim=0) + + # Gather the labels from every replica. + if labels is not None: + labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(labels_list, labels) + # Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica.with the embeddings produced on this replica, which do have gradients. 
+ labels_list[torch.distributed.get_rank()] = labels + all_labels = torch.cat(labels_list, dim=0) + else: + all_labels = None + return all_emb, all_labels, labels def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None): all_emb, all_labels, labels = gather(emb, labels) @@ -58,7 +80,17 @@ def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None): def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=None): device = labels.device - curr_batch_idx = torch.arange(len(labels), device=device) + + + # curr_batch_idx should be the local batch corresponding idx of ref_batch(global batch) + # curr_batch_idx = torch.arange(len(labels), device=device) # this is wrong + + local_bs = len(ref_labels) // torch.distributed.get_world_size() + local_b_start_idx = torch.distributed.get_rank() * local_bs + curr_batch_idx = torch.arange(local_b_start_idx, (local_b_start_idx + local_bs), device=device) + + + if miner: indices_tuple = miner(embeddings, labels, ref_emb, ref_labels) else: @@ -66,12 +98,26 @@ def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=N return lmu.remove_self_comparisons(indices_tuple, curr_batch_idx, len(ref_labels)) -def gather_enqueue_mask(enqueue_mask, device): +def gather_enqueue_mask_bak(enqueue_mask, device): if enqueue_mask is None: return enqueue_mask enqueue_mask = c_f.to_device(enqueue_mask, device=device) return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0) +def gather_enqueue_mask(enqueue_mask, device): + if enqueue_mask is None: + return enqueue_mask + enqueue_mask = c_f.to_device(enqueue_mask, device=device) + # Gather the enqueue_mask from every replica. + enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(enqueue_mask_list, enqueue_mask) + + # Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica.with the embeddings produced on this replica, which do have gradients. + enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask + + return torch.cat(enqueue_mask_list, dim=0) + + def select_ref_or_regular(regular, ref): return regular if ref is None else ref @@ -123,12 +169,14 @@ def forward_regular_loss( if indices_tuple is None: indices_tuple = get_indices_tuple(labels, all_labels) loss = self.loss(emb, labels, indices_tuple, all_emb, all_labels) + + return loss else: loss = self.loss( all_emb, all_labels, indices_tuple, all_ref_emb, all_ref_labels ) - return loss * world_size + return loss * world_size def forward_cross_batch( self, @@ -152,9 +200,71 @@ def forward_cross_batch( emb, labels, ref_emb, ref_labels ) enqueue_mask = gather_enqueue_mask(enqueue_mask, emb.device) - loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_mask) - return loss * world_size + # print(f'all_gathered emb size: {all_emb.shape}') + # print(f'all_labels emb size: {all_labels.shape}') + # print(f'print enqueue_mask after all gather on {torch.distributed.get_rank()}: {enqueue_mask}') + + # loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_mask) + loss = self.forward_cross_batch_dist_helper(self.loss, all_emb, all_labels, indices_tuple, enqueue_mask) + return loss # unit test has confirmed that this is right. 
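+
+    # forward_cross_batch_dist_helper mirrors CrossBatchMemory's forward pass for the DDP
+    # case where `embeddings` and `labels` have already been gathered from all replicas:
+    # the full gathered batch (or its enqueue-masked part) is pushed into the memory bank,
+    # so every replica keeps an identical queue, while the loss itself is computed only on
+    # this replica's slice of the gathered batch (selected via get_rank() and get_world_size()).
+    # Because each replica only contributes its local anchors, the returned loss is not
+    # multiplied by world_size.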
+ + def forward_cross_batch_dist_helper(self, loss_inst, embeddings, labels, indices_tuple=None, enqueue_mask=None): + if indices_tuple is not None and enqueue_mask is not None: + raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") + if enqueue_mask is not None: + assert len(enqueue_mask) == len(embeddings) + else: + assert len(embeddings) <= len(loss_inst.embedding_memory) + loss_inst.reset_stats() + device = embeddings.device + labels = c_f.to_device(labels, device=device) + loss_inst.embedding_memory = c_f.to_device( + loss_inst.embedding_memory, device=device, dtype=embeddings.dtype + ) + loss_inst.label_memory = c_f.to_device( + loss_inst.label_memory, device=device, dtype=labels.dtype + ) + + if enqueue_mask is not None: + emb_for_queue = embeddings[enqueue_mask] + labels_for_queue = labels[enqueue_mask] + embeddings = embeddings[~enqueue_mask] + labels = labels[~enqueue_mask] + do_remove_self_comparisons = False + else: + emb_for_queue = embeddings + labels_for_queue = labels + do_remove_self_comparisons = True + + # ==== DDP specific =====# + # get local device emb instead of using all gathered to be efficient + local_bs = len(embeddings)//torch.distributed.get_world_size() + local_b_start_idx = torch.distributed.get_rank() * local_bs + embeddings = embeddings[local_b_start_idx:(local_b_start_idx+local_bs), :] + labels = labels[local_b_start_idx:(local_b_start_idx + local_bs)] + # ==== end DDP specific ======# + + queue_batch_size = len(emb_for_queue) + loss_inst.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) + + if not loss_inst.has_been_filled: + E_mem = loss_inst.embedding_memory[: loss_inst.queue_idx] + L_mem = loss_inst.label_memory[: loss_inst.queue_idx] + else: + E_mem = loss_inst.embedding_memory + L_mem = loss_inst.label_memory + + indices_tuple = loss_inst.create_indices_tuple( + embeddings, + labels, + E_mem, + L_mem, + indices_tuple, + do_remove_self_comparisons, + ) + loss = loss_inst.loss(embeddings, labels, indices_tuple, E_mem, L_mem) + return loss class DistributedMinerWrapper(torch.nn.Module): def __init__(self, miner, efficient=False): @@ -172,9 +282,16 @@ def forward(self, emb, labels, ref_emb=None, ref_labels=None): all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref( emb, labels, ref_emb, ref_labels ) + if self.efficient: all_labels = select_ref_or_regular(all_labels, all_ref_labels) all_emb = select_ref_or_regular(all_emb, all_ref_emb) + + # print('in DistributedMinerWrapper: ') + # print(f'all_gathered emb size: {all_emb.shape}') + # print(f'all_labels emb size: {all_labels.shape}, all labels: {all_labels}') + # print(f'labels emb size: {labels.shape}, labels: {labels}') + return get_indices_tuple(labels, all_labels, emb, all_emb, self.miner) else: return self.miner(all_emb, all_labels, all_ref_emb, all_ref_labels) diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py index baaa71fe..afbc33c1 100644 --- a/tests/utils/test_distributed.py +++ b/tests/utils/test_distributed.py @@ -19,6 +19,7 @@ def parameters_are_equal(model1, model2): output = True for p1, p2 in zip(model1.parameters(), model2.parameters()): + print((p1.data - p2.data).abs().max()) output &= torch.allclose(p1.data, p2.data, rtol=1e-2) return output @@ -92,6 +93,7 @@ def single_process_function( optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr) original_model = original_model.to(device) + print("assert NOT equal") assert not parameters_are_equal(original_model, ddp_mp_model.module) for i in 
range(iterations): @@ -122,6 +124,7 @@ def single_process_function( optimizer.step() dist.barrier() + print("assert equal") assert parameters_are_equal(original_model, ddp_mp_model.module) dist.barrier() cleanup() @@ -184,7 +187,7 @@ def loss_and_miner_tester( pass_labels_to_loss_fn=True, use_xbm_enqueue_mask=False, ): - torch.manual_seed(75210) + # torch.manual_seed(75210) loss_kwargs = {} if loss_kwargs is None else loss_kwargs miner_kwargs = {} if miner_kwargs is None else miner_kwargs if TEST_DEVICE == torch.device("cpu"): @@ -205,6 +208,7 @@ def loss_and_miner_tester( original_model = ToyMpModel().type(dtype) model = ToyMpModel().type(dtype) model.load_state_dict(original_model.state_dict()) + print("assert identical") self.assertTrue(parameters_are_equal(original_model, model)) original_model = original_model.to(TEST_DEVICE) @@ -331,8 +335,8 @@ def loss_and_miner_tester( ) def test_distributed_tuple_loss(self): - for xbm in [False, True]: - for use_ref in [False, True]: + for xbm in [True]: + for use_ref in [False]: for use_xbm_enqueue_mask in [False, True]: if xbm and use_ref: continue diff --git a/tests/utils/test_distributed_xbm_queue.py b/tests/utils/test_distributed_xbm_queue.py new file mode 100644 index 00000000..d732abfd --- /dev/null +++ b/tests/utils/test_distributed_xbm_queue.py @@ -0,0 +1,298 @@ +import unittest +from .. import TEST_DEVICE, TEST_DTYPES + + +import os +import argparse +import random +import warnings +import numpy as np + + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import Dataset, DataLoader + +from pytorch_metric_learning import losses, miners +from pytorch_metric_learning.utils import distributed as pml_dist +# def _init_fn(args): +# def func(worker_id): +# SEED=args.seed + worker_id +# torch.manual_seed(SEED) +# torch.cuda.manual_seed(SEED) +# np.random.seed(SEED) +# random.seed(SEED) +# torch.backends.cudnn.deterministic=True +# torch.backends.cudnn.benchmark = False +# return func + +def _init_fn(worker_id): + SEED = worker_id + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + np.random.seed(SEED) + random.seed(SEED) + torch.backends.cudnn.deterministic=True + torch.backends.cudnn.benchmark = False + +def example(local_rank, rank, world_size, args): + print(f"current local_rank is {local_rank}, current rank is {rank}") + + + # ============================DDP specific ===================================# + # create default process group + dist.init_process_group("nccl")# , init_method='env://')#, rank=rank, world_size=world_size) # rank=rank, world_size=world_size) # not needed since aml launch scripts have the EVN VAR set already? + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + # dataset + dataset = RandomDataset(args.sample_size) + + print(f"world size is {world_size}") + train_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + drop_last=True, + seed = args.seed#, + # num_replicas=world_size, # not needed since aml launch scripts have the EVN VAR set already? 
+ # rank=rank + ) + # ============================DDP specific ===================================# + args.batch_size //= world_size + dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers_per_process, pin_memory=True, sampler=train_sampler, worker_init_fn=_init_fn) + + # create local model + model_q = nn.Linear(10, 10).to(local_rank) + with torch.no_grad(): + model_k = nn.Linear(10, 10).to(local_rank) + + #============================DDP specific end ===================================# + # construct DDP model + model_q = DDP(model_q, device_ids=[local_rank]) + + # define loss function and optimizer + # miner = miners.MultiSimilarityMiner() + # set margins to ensure that no pairs are left out for this example + miner = miners.PairMarginMiner(pos_margin=0, neg_margin=100) + # miner = pml_dist.DistributedMinerWrapper(miner=miner, efficient=True) ## miner is encoded inside loss in our case, then the miner will have all gathered embs, so no wrapper is needed + loss_fn = losses.CrossBatchMemory(loss=losses.NTXentLoss(temperature=0.07), embedding_size=10, + memory_size=args.memory_size, miner=miner) + loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn) + # ============================DDP specific end ===================================# + optimizer = optim.SGD(model_q.parameters(), lr=0.001) + + + ## train loop + # Iterate over the data using the DataLoader + for epoch in range(args.num_of_epoch): + train_sampler.set_epoch(epoch) + for worker_id, index, batch_inputs, batch_outputs in dataloader: + print(f"in epoch {epoch}, worker id {worker_id}, index {index}, label {batch_outputs}") + # forward pass + embeds_Q = model_q(batch_inputs.to(local_rank)) + labels = batch_outputs.to(local_rank) + + # compute output + with torch.no_grad(): # no gradient to keys + copy_params(model_q, model_k, m=0.99) + # ============================DDP specific ===================================# + # shuffle for making use of BN + imgK, idx_unshuffle = batch_shuffle_ddp(batch_inputs.to(local_rank)) + embeds_K = model_k(imgK) + embeds_K = batch_unshuffle_ddp(embeds_K, idx_unshuffle) + # ============================DDP specific end ===================================# + + # ========================== for debug ==================================# + # print(f"input before shuffle on rank {rank}: {batch_inputs}") + # print(f"input after shuffle on rank {rank}: {imgK}") + gpu_idx = torch.distributed.get_rank() + + num_gpus = world_size + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + # ========================== end for debug ==================================# + ## same as original MOCO: same image augmented with the same label + all_enc = torch.cat([embeds_Q, embeds_K], dim=0) + labels, enqueue_mask = create_labels(embeds_Q.size(0), labels, local_rank) + # # ========================= debug =======================# + # print("======================== all gathering ====================") + # world_size = torch.distributed.get_world_size() + # print(f"world_size is {world_size}") + # all_enc, labels, _, _, _ = pml_dist.gather_emb_and_ref( + # all_enc, labels + # ) + # enqueue_mask = pml_dist.gather_enqueue_mask(enqueue_mask, all_enc.device) + # # ========================= end of debug =======================# + loss = loss_fn(all_enc, labels, + enqueue_mask=enqueue_mask) # miner will be used in loss_fn if initialized with miner + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + + # check pair size + 
print(f'label on mem bank on rank{rank}: {loss_fn.loss.label_memory}') # here we should see the same labels on mem queue on different workers + print(f"bs is {batch_inputs.shape[0]}") + print("num of pos pairs: ", loss_fn.loss.miner.num_pos_pairs) + print("num of neg pairs: ", loss_fn.loss.miner.num_neg_pairs) + print(f"epoch {epoch}, loss is {loss}") + # print weights + for param in model_q.parameters(): + print(param.data) + + dist.destroy_process_group() + + +def copy_params(encQ, encK, m=None): + if m is None: + for param_q, param_k in zip(encQ.parameters(), encK.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + else: + for param_q, param_k in zip(encQ.parameters(), encK.parameters()): + param_k.data = param_k.data * m + param_q.data * (1.0 - m) + +class RandomDataset(Dataset): + def __init__(self, sample_size): + self.samples = torch.randn(sample_size, 10) + # self.samples = torch.ones_like(self.samples) + self.labels = torch.arange(sample_size) + print(f'data set labels: {self.labels}') + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + worker_info = torch.utils.data.get_worker_info() + if not worker_info: + raise NotImplementedError('Not implemented for num_workers=0') + print(f"in worker {worker_info.id}, index {index}, label is {self.labels[index]}") + return worker_info.id, index, self.samples[index], self.labels[index] +def create_labels(num_pos_pairs, labels, device): + # create labels that indicate what the positive pairs are + labels = torch.cat((labels, labels)).to(device) + + # we want to enqueue the output of encK, which is the 2nd half of the batch + enqueue_mask = torch.zeros(len(labels)).bool() + enqueue_mask[num_pos_pairs:] = True + return labels, enqueue_mask +@torch.no_grad() +def batch_shuffle_ddp(x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. *** + *** x should be on local_rank device + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).to(x.device)# .cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + +@torch.no_grad() +def batch_unshuffle_ddp(x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
+ """ + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + +class TestDistributedXbmQueue(unittest.TestCase): + def main(self): + # Create the argument parser + parser = argparse.ArgumentParser(description='Train a metric learning model for MRI protocol classification.') + + # Add arguments to the parser + parser.add_argument('--sample_size', type=int, + default=32, + help='batch_size') + parser.add_argument('--batch_size', type=int, + default=32, + help='global batch_size') + parser.add_argument('--seed', type=int, + default=None, + help='world_size') + parser.add_argument('--memory_size', type=int, + default=32, + help='memory_size') + parser.add_argument('--num_workers_per_process', type=int, + default=8, + help='num_workers') + parser.add_argument('--num_of_epoch', type=int, + default=5, + help='num_of_epoch') + + args = parser.parse_args() + ######################################################### + # Get PyTorch environment variables for distributed training. + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # here, why no use of mp.spawn? ==> aml submit script is supposed to take care of the EVN VARs: + # MASTER_ADDR, MASTER_PORT, NODE_RANK, WORLD_SIZE + # RANK, LOCAL_RANK + os.environ['NCCL_DEBUG'] = 'INFO' + print(f"args.batch_size is {args.batch_size}, world_size is {world_size}") + assert (args.batch_size % world_size) == 0 + example(local_rank, rank, world_size, args) + + +if __name__=="__main__": + unittest.main() \ No newline at end of file