FedGAN: Federated Deep Generative Models

Python 3.8+ PyTorch Code style: black

A comprehensive Python package for Federated Deep Generative Models using conditional GANs. This repository implements state-of-the-art federated learning algorithms for training generative adversarial networks across distributed clients while preserving privacy.

Key Features

  • Federated Learning: Multiple clients train locally; server aggregates using FedAvg or custom strategies
  • Conditional Generation: Class-conditional DCGAN for label-controlled image synthesis
  • Modular Design: Easy-to-extend architecture with pluggable components
  • Comprehensive Evaluation: Built-in metrics including Classification Score, EMD, IS, and FID
  • Flexible Configuration: YAML-based configuration system for easy experimentation
  • Ready-to-Use Notebooks: Three complementary Jupyter notebooks for different use cases
  • Tested & Reliable: Comprehensive test suite ensuring code quality

Repository Structure

federated-deep-generative-models/
|- src/fedgan/              # Core Python package
|  |- models.py             # GAN architectures (cDCGAN, MNIST classifier)
|  |- federated.py          # Federated learning algorithms
|  |- data.py               # Data loading and federated splitting
|  |- metrics.py            # Evaluation metrics
|  |- utils.py              # Utility functions
|  `- config.py             # Configuration management
|- notebooks/               # Jupyter notebooks
|  |- 01_federated_conditional_dcgan.ipynb
|  |- 02_fedgan_v2_kaggle_experiment.ipynb
|  `- 03_custom_synchronization_strategies.ipynb
|- configs/                 # Configuration files
|  |- default_config.yaml
|  |- kaggle_config.yaml
|  `- custom_sync_config.yaml
|- tests/                   # Unit tests
|- docs/                    # Documentation
|- data/                    # Dataset storage (auto-created)
|- outputs/                 # Experiment outputs (auto-created)
|- checkpoints/             # Model checkpoints (auto-created)
|- pyproject.toml           # Project configuration
|- requirements.txt         # Dependencies
`- README.md                # This file

Quick Start

Installation

Option 1: Install from source (recommended)

# Clone the repository
git clone https://github.com/fedgan-team/federated-deep-generative-models.git
cd federated-deep-generative-models

# Create and activate conda environment
conda create -n fedgan python=3.10 -y
conda activate fedgan

# Install PyTorch (adjust for your system)
# For CUDA 11.8:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For CPU-only:
# pip install torch torchvision torchaudio

# Install the package and dependencies
pip install -e .
# Or install dependencies separately:
# pip install -r requirements.txt

Option 2: Using pip (when available on PyPI)

pip install fedgan

Option 3: Kaggle Environment

For Kaggle users:

  1. Open notebooks/02_fedgan_v2_kaggle_experiment.ipynb in a Kaggle Notebook
  2. Enable GPU acceleration
  3. Run all cells - no additional setup required!

Verify Installation

import fedgan
print(f"FedGAN version: {fedgan.__version__}")

# Quick test
from fedgan import cDCGenerator, cDCDiscriminator
gen = cDCGenerator()
disc = cDCDiscriminator()
print("Installation successful!")

Usage

Using the Python API

import torch
from fedgan import FederatedDataset, cDCGenerator, cDCDiscriminator
from fedgan import federated_cgan_training, EvaluationMetrics
from fedgan.config import load_config

# Load configuration
config = load_config("configs/default_config.yaml")

# Set up federated data
dataset = FederatedDataset("MNIST", root="./data")
client_loaders = dataset.get_federated_loaders(
    num_clients=config.federated.num_clients,
    split_strategy=config.data.split_strategy,
    batch_size=config.data.batch_size
)

# Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_generator = cDCGenerator().to(device)
global_discriminator = cDCDiscriminator().to(device)

# Run federated training
trained_disc, trained_gen, metrics = federated_cgan_training(
    global_discriminator,
    global_generator,
    client_loaders,
    rounds=config.federated.communication_rounds,
    K=config.federated.clients_per_round,
    local_epochs=config.federated.local_epochs,
    device=device
)

# Evaluate results (assumes a trained MNIST classifier and a test loader are available)
evaluator = EvaluationMetrics(trained_classifier, test_loader, device)
results = evaluator.compute_all_metrics(trained_gen)
print(f"Classification Score: {results['classification_score']:.3f}")

Using Jupyter Notebooks

We provide three complementary notebooks for different use cases:

1. Basic Federated Training

Notebook: 01_federated_conditional_dcgan.ipynb

Perfect for beginners! Demonstrates:

  • Conditional DCGAN architecture
  • Federated training with FedAvg
  • Evaluation metrics and visualization
  • Model checkpointing

2. Kaggle-Ready Experiment

Notebook: 02_fedgan_v2_kaggle_experiment.ipynb

Optimized for Kaggle environment:

  • Fast training with reduced parameters
  • Kaggle-friendly paths and logging
  • Competition-ready pipeline

3. Advanced Synchronization

Notebook: 03_custom_synchronization_strategies.ipynb

For researchers and advanced users:

  • Custom aggregation strategies
  • Non-IID data distributions
  • Advanced federated algorithms

Configuration System

FedGAN uses YAML configuration files for easy experimentation:

# configs/my_experiment.yaml
federated:
  num_clients: 20
  communication_rounds: 50
  clients_per_round: 4
  strategy: "fedavg"

data:
  dataset: "MNIST"
  split_strategy: "dirichlet"  # "iid", "dirichlet", "pathological"
  dirichlet_alpha: 0.5

model:
  generator:
    nz: 100
    ngf: 64
  discriminator:
    ndf: 64

training:
  learning_rate: 0.0002
  betas: [0.5, 0.999]

Load and use:

from fedgan.config import load_config
config = load_config("configs/my_experiment.yaml")

Supported Datasets

Dataset        Classes  Resolution       Split Strategies
MNIST          10       28x28 -> 64x64   IID, Non-IID Dirichlet, Pathological
Fashion-MNIST  10       28x28 -> 64x64   IID, Non-IID Dirichlet, Pathological

Note: Additional datasets (CIFAR-10, CelebA) can be easily integrated by extending the FederatedDataset class.

Data Distribution Strategies

  • IID: Data is uniformly distributed across clients
  • Dirichlet: Non-IID distribution using Dirichlet sampling (controllable heterogeneity)
  • Pathological: Each client has data from only a subset of classes (extreme non-IID)
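To make the Dirichlet strategy concrete, here is a stdlib-only sketch of how such a split can be produced (illustrative only; `dirichlet_split` is a hypothetical helper, and the package's own implementation in `data.py` may differ):

```python
import random

def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Partition sample indices into non-IID client shards.

    For each class, a Dirichlet(alpha) draw decides what fraction of that
    class's samples each client receives; smaller alpha means more skew.
    """
    rng = random.Random(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        # Sample Dirichlet(alpha) proportions via normalized Gamma draws
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        start = 0
        for client, g in enumerate(gammas):
            # Last client takes the remainder so every index is assigned
            end = len(idx) if client == num_clients - 1 else start + round(g / total * len(idx))
            client_indices[client].extend(idx[start:end])
            start = end
    return client_indices

# Toy demo: 100 samples over 10 balanced classes, 4 clients, strongly non-IID
labels = [i % 10 for i in range(100)]
shards = dirichlet_split(labels, num_clients=4, alpha=0.1)
```

With `alpha=0.1` most clients end up dominated by a few classes; raising `alpha` toward infinity recovers a near-IID split.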

Evaluation Metrics

FedGAN provides comprehensive evaluation metrics for generative models:

Quantitative Metrics

Metric                Description                                                      Range          Better
Classification Score  Accuracy of generated samples vs. intended labels                [0, 1]         Higher
EMD Score             Earth Mover's Distance between real and generated distributions  [0, infinity)  Lower
Inception Score (IS)  Quality and diversity of generated samples                       [1, infinity)  Higher
FID Score             Frechet Inception Distance (distribution similarity)             [0, infinity)  Lower
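The package's EMD implementation is not shown here, but for 1-D statistics (e.g. per-pixel intensity values) the Earth Mover's Distance between two equal-size samples reduces to the mean absolute difference of their sorted values; a stdlib-only sketch:

```python
import random

def emd_1d(xs, ys):
    """1-D Earth Mover's (Wasserstein-1) distance between equal-size samples:
    the mean absolute difference of the sorted values."""
    assert len(xs) == len(ys)
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

rng = random.Random(0)
real = [rng.gauss(0.5, 0.1) for _ in range(5000)]     # stand-in for real pixel stats
matched = [rng.gauss(0.5, 0.1) for _ in range(5000)]  # well-matched generator
shifted = [rng.gauss(0.8, 0.3) for _ in range(5000)]  # poorly matched generator
```

As expected, `emd_1d(real, matched)` is much smaller than `emd_1d(real, shifted)`: lower EMD means the generated distribution is closer to the real one.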

Qualitative Assessment

  • Sample Grids: Visual inspection of generated images per class
  • Training Curves: Loss progression over federated rounds
  • Diversity Analysis: Visual variety within and across classes
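The Classification Score above is simply the fraction of generated samples that a pretrained classifier assigns to their intended conditioning label. A toy sketch (with scalar "samples" and `round` as a stand-in classifier; in practice the classifier is a trained MNIST CNN and the samples are images):

```python
def classification_score(classifier, samples, intended_labels):
    """Fraction of generated samples that a classifier assigns
    to their intended conditioning label."""
    predictions = [classifier(x) for x in samples]
    correct = sum(p == y for p, y in zip(predictions, intended_labels))
    return correct / len(samples)

# Toy stand-in: "samples" are scalars near their label, classifier just rounds
samples = [2.9, 7.1, 4.0, 5.6]
intended = [3, 7, 4, 5]
score = classification_score(round, samples, intended)  # 3 of 4 correct -> 0.75
```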

Advanced Features

Custom Federated Strategies

Extend the base strategy class to implement novel algorithms:

import torch

from fedgan.federated import FederatedStrategy

class FedProxStrategy(FederatedStrategy):
    def __init__(self, mu=0.01):
        super().__init__("FedProx")
        self.mu = mu  # proximal coefficient, applied in each client's local objective

    def aggregate(self, global_model, local_models):
        # Server-side FedProx aggregation is plain parameter averaging; the
        # proximal term (mu/2) * ||w - w_global||^2 is added to each client's
        # local loss during training, not here.
        global_state = global_model.state_dict()
        for key in global_state:
            global_state[key] = torch.stack(
                [m.state_dict()[key].float() for m in local_models]
            ).mean(dim=0)
        global_model.load_state_dict(global_state)
        return global_model

Non-IID Data Simulation

# Dirichlet distribution (adjustable heterogeneity)
client_loaders = dataset.get_federated_loaders(
    num_clients=10,
    split_strategy="dirichlet",
    alpha=0.1  # Lower = more non-IID
)

# Pathological split (extreme non-IID)
client_loaders = dataset.get_federated_loaders(
    num_clients=10,
    split_strategy="pathological",
    classes_per_client=2  # Each client sees only 2 classes
)

Configuration Cheatsheet

Common hyperparameters (set near the top of each notebook):

  • num_clients: 5-20 for MNIST demos
  • rounds: 50-200 (increase for higher fidelity)
  • local_epochs: 1-5 (more local steps mean fewer communication rounds, but risk client drift)
  • batch_size: 64-256
  • z_dim: 64-128 (latent code)
  • lrG, lrD: typically 2e-4 with Adam betas=(0.5, 0.999) for DCGAN-style training

Federated specifics:

  • Aggregator: FedAvg (baseline), custom rule in the Custom Sync notebook
  • Client sampling: all or a subset each round
  • Weighting: by client dataset size (recommended)

Hyperparameter Guidelines

Recommended Starting Values

Parameter             MNIST    Fashion-MNIST  Notes
num_clients           10-20    10-20          More clients = more realistic federation
communication_rounds  50-100   100-200        Increase for better convergence
local_epochs          1-3      1-3            Higher values may cause client drift
batch_size            64-128   64-128         Adjust based on memory constraints
learning_rate         2e-4     2e-4           Classic DCGAN learning rate
nz (noise dim)        100      100            Latent space dimension

Federated-Specific Parameters

  • Client Sampling: Select 20-50% of clients per round for efficiency
  • Synchronization: Enable both generator and discriminator sync for stability
  • Aggregation: Start with FedAvg, experiment with weighted variants
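Size-weighted aggregation can be sketched framework-agnostically. Here parameters are plain Python lists and `weighted_fedavg` is a hypothetical helper (the package's actual aggregation operates on PyTorch state dicts):

```python
def weighted_fedavg(client_states, client_sizes):
    """FedAvg weighted by local dataset size.

    client_states: list of {param_name: list of floats} dicts
    client_sizes:  number of local samples per client
    """
    total = sum(client_sizes)
    weights = [n / total for n in client_sizes]
    return {
        key: [sum(w * state[key][i] for w, state in zip(weights, client_states))
              for i in range(len(client_states[0][key]))]
        for key in client_states[0]
    }

# Two clients with a 3x data imbalance: the larger client dominates
a = {"w": [0.0, 0.0]}
b = {"w": [4.0, 8.0]}
avg = weighted_fedavg([a, b], client_sizes=[10, 30])  # weights 0.25 / 0.75 -> [3.0, 6.0]
```

Weighting by dataset size keeps the aggregate unbiased when clients hold very different amounts of data, which is why it is the recommended default above.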

Development & Contributing

Setting up Development Environment

# Clone and install in development mode
git clone https://github.com/fedgan-team/federated-deep-generative-models.git
cd federated-deep-generative-models
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

Running Tests

# Run all tests
pytest

# Run with coverage
pytest --cov=fedgan --cov-report=html

# Run specific test file
pytest tests/test_models.py -v

Troubleshooting

Common Issues

Problem               Symptoms                                     Solution
Training Instability  Loss oscillations, poor samples              Reduce learning rate, increase local epochs
Mode Collapse         Generated samples lack diversity             Lower discriminator LR, add noise to inputs
Client Drift          Performance degrades with more local epochs  Reduce local epochs or add regularization
Memory Issues         CUDA out of memory                           Reduce batch size or model dimensions

Performance Tips

  • Use GPU: Training is 10-50x faster on CUDA-enabled devices
  • Batch Size: Larger batches often lead to more stable training
  • Checkpointing: Save models regularly to resume interrupted training
  • Monitoring: Use evaluation metrics to catch problems early
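The checkpointing tip can be sketched as follows; `save_checkpoint` and `load_checkpoint` are hypothetical helpers (not part of the fedgan API), shown here with tiny stand-in modules:

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(path, round_idx, generator, discriminator):
    # Persist enough state to resume a federated run after an interruption
    torch.save({
        "round": round_idx,
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
    }, path)

def load_checkpoint(path, generator, discriminator):
    ckpt = torch.load(path, map_location="cpu")
    generator.load_state_dict(ckpt["generator"])
    discriminator.load_state_dict(ckpt["discriminator"])
    return ckpt["round"]  # resume from the next communication round

# Demo with tiny stand-in "generator" and "discriminator" modules
path = os.path.join(tempfile.gettempdir(), "fedgan_demo_ckpt.pt")
save_checkpoint(path, 7, nn.Linear(4, 4), nn.Linear(4, 1))
resumed_round = load_checkpoint(path, nn.Linear(4, 4), nn.Linear(4, 1))
```

Saving after every few communication rounds (rather than every round) keeps disk usage modest while still limiting how much work an interruption can cost.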

References

Federated Learning

  • McMahan, B., et al. "Communication-efficient learning of deep networks from decentralized data." AISTATS, 2017.
  • Li, T., et al. "Federated optimization in heterogeneous networks." MLSys, 2020.

Generative Adversarial Networks

  • Goodfellow, I., et al. "Generative adversarial nets." NeurIPS, 2014.
  • Radford, A., et al. "Unsupervised representation learning with deep convolutional generative adversarial networks." ICLR, 2016.
  • Mirza, M., & Osindero, S. "Conditional generative adversarial nets." arXiv preprint, 2014.

Evaluation Metrics

  • Salimans, T., et al. "Improved techniques for training GANs." NeurIPS, 2016. (Inception Score)
  • Heusel, M., et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium." NeurIPS, 2017. (FID)

Contributing

We welcome contributions! Areas for contribution include:

  • New Datasets: Add support for CIFAR-10, CelebA, etc.
  • Advanced Algorithms: Implement FedProx, SCAFFOLD, FedNova
  • Privacy: Add differential privacy mechanisms
  • Metrics: Implement additional evaluation metrics
  • Documentation: Improve docs and tutorials
