A comprehensive Python package for Federated Deep Generative Models using conditional GANs. This repository implements federated learning algorithms for training conditional generative adversarial networks across distributed clients while keeping raw training data local to each client.
- Federated Learning: Multiple clients train locally; server aggregates using FedAvg or custom strategies
- Conditional Generation: Class-conditional DCGAN for label-controlled image synthesis
- Modular Design: Easy-to-extend architecture with pluggable components
- Comprehensive Evaluation: Built-in metrics including Classification Score, EMD, IS, and FID
- Flexible Configuration: YAML-based configuration system for easy experimentation
- Ready-to-Use Notebooks: Three complementary Jupyter notebooks for different use cases
- Tested & Reliable: Comprehensive test suite ensuring code quality
```
federated-deep-generative-models/
|- src/fedgan/                  # Core Python package
|  |- models.py                 # GAN architectures (cDCGAN, MNIST classifier)
|  |- federated.py              # Federated learning algorithms
|  |- data.py                   # Data loading and federated splitting
|  |- metrics.py                # Evaluation metrics
|  |- utils.py                  # Utility functions
|  `- config.py                 # Configuration management
|- notebooks/                   # Jupyter notebooks
|  |- 01_federated_conditional_dcgan.ipynb
|  |- 02_fedgan_v2_kaggle_experiment.ipynb
|  `- 03_custom_synchronization_strategies.ipynb
|- configs/                     # Configuration files
|  |- default_config.yaml
|  |- kaggle_config.yaml
|  `- custom_sync_config.yaml
|- tests/                       # Unit tests
|- docs/                        # Documentation
|- data/                        # Dataset storage (auto-created)
|- outputs/                     # Experiment outputs (auto-created)
|- checkpoints/                 # Model checkpoints (auto-created)
|- pyproject.toml               # Project configuration
|- requirements.txt             # Dependencies
`- README.md                    # This file
```
```bash
# Clone the repository
git clone https://github.com/fedgan-team/federated-deep-generative-models.git
cd federated-deep-generative-models

# Create and activate conda environment
conda create -n fedgan python=3.10 -y
conda activate fedgan

# Install PyTorch (adjust for your system)
# For CUDA 11.8:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For CPU-only:
# pip install torch torchvision torchaudio

# Install the package and dependencies
pip install -e .
# Or install dependencies separately:
# pip install -r requirements.txt
```

Or install from PyPI:

```bash
pip install fedgan
```

For Kaggle users:

- Open `notebooks/02_fedgan_v2_kaggle_experiment.ipynb` in a Kaggle Notebook
- Enable GPU acceleration
- Run all cells - no additional setup required!
Verify the installation:

```python
import fedgan
print(f"FedGAN version: {fedgan.__version__}")

# Quick test
from fedgan import cDCGenerator, cDCDiscriminator
gen = cDCGenerator()
disc = cDCDiscriminator()
print("Installation successful!")
```

For a quick start, here is a minimal end-to-end training run:

```python
import torch
from fedgan import FederatedDataset, cDCGenerator, cDCDiscriminator
from fedgan import federated_cgan_training, EvaluationMetrics
from fedgan.config import load_config
# Load configuration
config = load_config("configs/default_config.yaml")
# Set up federated data
dataset = FederatedDataset("MNIST", root="./data")
client_loaders = dataset.get_federated_loaders(
    num_clients=config.federated.num_clients,
    split_strategy=config.data.split_strategy,
    batch_size=config.data.batch_size
)
# Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_generator = cDCGenerator().to(device)
global_discriminator = cDCDiscriminator().to(device)
# Run federated training
trained_disc, trained_gen, metrics = federated_cgan_training(
    global_discriminator,
    global_generator,
    client_loaders,
    rounds=config.federated.communication_rounds,
    K=config.federated.clients_per_round,
    local_epochs=config.federated.local_epochs,
    device=device
)
# Evaluate results. `trained_classifier` is a classifier pretrained on real
# data (see fedgan.models) and `test_loader` is a held-out test DataLoader;
# both must be prepared before this step.
evaluator = EvaluationMetrics(trained_classifier, test_loader, device)
results = evaluator.compute_all_metrics(trained_gen)
print(f"Classification Score: {results['classification_score']:.3f}")
```

We provide three complementary notebooks for different use cases:
Notebook: 01_federated_conditional_dcgan.ipynb
Perfect for beginners! Demonstrates:
- Conditional DCGAN architecture
- Federated training with FedAvg
- Evaluation metrics and visualization
- Model checkpointing
Notebook: 02_fedgan_v2_kaggle_experiment.ipynb
Optimized for Kaggle environment:
- Fast training with reduced parameters
- Kaggle-friendly paths and logging
- Competition-ready pipeline
Notebook: 03_custom_synchronization_strategies.ipynb
For researchers and advanced users:
- Custom aggregation strategies
- Non-IID data distributions
- Advanced federated algorithms
FedGAN uses YAML configuration files for easy experimentation:
```yaml
# configs/my_experiment.yaml
federated:
  num_clients: 20
  communication_rounds: 50
  clients_per_round: 4
  strategy: "fedavg"

data:
  dataset: "MNIST"
  split_strategy: "dirichlet"  # "iid", "dirichlet", "pathological"
  dirichlet_alpha: 0.5

model:
  generator:
    nz: 100
    ngf: 64
  discriminator:
    ndf: 64

training:
  learning_rate: 0.0002
  betas: [0.5, 0.999]
```

Load and use:

```python
from fedgan.config import load_config

config = load_config("configs/my_experiment.yaml")
```

| Dataset | Classes | Resolution | Split Strategies |
|---|---|---|---|
| MNIST | 10 | 28x28 -> 64x64 | IID, Non-IID Dirichlet, Pathological |
| Fashion-MNIST | 10 | 28x28 -> 64x64 | IID, Non-IID Dirichlet, Pathological |
Note: Additional datasets (CIFAR-10, CelebA) can be integrated by extending the `FederatedDataset` class.
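As a rough illustration of such an extension, the sketch below subclasses `FederatedDataset` and swaps in CIFAR-10 via torchvision. The `_load_dataset` hook is hypothetical; check `src/fedgan/data.py` for the actual extension point:

```python
from torchvision import datasets, transforms

from fedgan import FederatedDataset


class CIFAR10Federated(FederatedDataset):
    """Sketch of a CIFAR-10 variant; `_load_dataset` is a hypothetical hook."""

    def _load_dataset(self, root, train=True):
        transform = transforms.Compose([
            transforms.Resize(64),  # match the 64x64 input used for MNIST
            transforms.ToTensor(),
            transforms.Normalize((0.5,) * 3, (0.5,) * 3),
        ])
        return datasets.CIFAR10(root=root, train=train,
                                download=True, transform=transform)
```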
- IID: Data is uniformly distributed across clients
- Dirichlet: Non-IID distribution using Dirichlet sampling (controllable heterogeneity)
- Pathological: Each client has data from only a subset of classes (extreme non-IID)
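To make the Dirichlet strategy concrete, here is a self-contained sketch of the standard per-class Dirichlet partition. It mirrors the general idea; the exact implementation in `fedgan.data` may differ:

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Assign sample indices to clients with per-class Dirichlet proportions.

    Lower alpha -> more skewed (non-IID) class distributions per client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        # Proportion of class c that each client receives
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        splits = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, splits)):
            client_indices[client_id].extend(part.tolist())
    return client_indices

# Example: 60k MNIST-like labels over 10 clients, alpha=0.5
labels = np.random.default_rng(1).integers(0, 10, size=60_000)
parts = dirichlet_partition(labels, num_clients=10, alpha=0.5)
print([len(p) for p in parts])
```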
FedGAN provides comprehensive evaluation metrics for generative models:
| Metric | Description | Range | Better |
|---|---|---|---|
| Classification Score | Accuracy of generated samples vs. intended labels | [0, 1] | Higher |
| EMD Score | Earth Mover's Distance between real and generated distributions | [0, infinity) | Lower |
| Inception Score (IS) | Quality and diversity of generated samples | [1, infinity) | Higher |
| FID Score | Frechet Inception Distance (distribution similarity) | [0, infinity) | Lower |
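For intuition, the Classification Score can be computed by generating labeled samples, classifying them with a model trained on real data, and measuring agreement with the intended labels. A minimal sketch, assuming the generator takes `(noise, labels)` with a DCGAN-style `(N, nz, 1, 1)` noise shape:

```python
import torch

@torch.no_grad()
def classification_score(generator, classifier, nz=100, n_per_class=100,
                         num_classes=10, device="cpu"):
    """Fraction of generated samples assigned to their intended class."""
    generator.eval()
    classifier.eval()
    correct, total = 0, 0
    for c in range(num_classes):
        z = torch.randn(n_per_class, nz, 1, 1, device=device)
        labels = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        fake = generator(z, labels)  # assumed signature: (noise, labels)
        preds = classifier(fake).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += n_per_class
    return correct / total
```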
- Sample Grids: Visual inspection of generated images per class
- Training Curves: Loss progression over federated rounds
- Diversity Analysis: Visual variety within and across classes
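The sample grids mentioned above can be rendered with `torchvision.utils.make_grid`; a sketch under the same `(noise, labels)` generator assumption:

```python
import torch
from torchvision.utils import make_grid, save_image

@torch.no_grad()
def save_class_grid(generator, nz=100, num_classes=10, samples_per_class=8,
                    path="outputs/sample_grid.png", device="cpu"):
    """Save a grid with one row per class, `samples_per_class` images each."""
    generator.eval()
    z = torch.randn(num_classes * samples_per_class, nz, 1, 1, device=device)
    labels = torch.arange(num_classes, device=device).repeat_interleave(samples_per_class)
    fake = generator(z, labels)
    # Map tanh outputs in [-1, 1] back to [0, 1] for display
    grid = make_grid(fake, nrow=samples_per_class, normalize=True, value_range=(-1, 1))
    save_image(grid, path)
```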
Extend the base strategy class to implement novel algorithms:
```python
import torch

from fedgan.federated import FederatedStrategy


class FedProxStrategy(FederatedStrategy):
    def __init__(self, mu=0.01):
        super().__init__("FedProx")
        self.mu = mu  # Proximal term coefficient (used in local objectives)

    def aggregate(self, global_model, local_models):
        # Server-side FedProx aggregation matches FedAvg: average local weights.
        # The proximal term (mu/2)*||w - w_global||^2 is client-side only.
        avg_state = {
            key: torch.stack([m.state_dict()[key].float() for m in local_models]).mean(0)
            for key in global_model.state_dict()
        }
        global_model.load_state_dict(avg_state)
        return global_model
```

Non-IID client splits are configured through `get_federated_loaders`:

```python
# Dirichlet distribution (adjustable heterogeneity)
client_loaders = dataset.get_federated_loaders(
    num_clients=10,
    split_strategy="dirichlet",
    alpha=0.1  # Lower = more non-IID
)

# Pathological split (extreme non-IID)
client_loaders = dataset.get_federated_loaders(
    num_clients=10,
    split_strategy="pathological",
    classes_per_client=2  # Each client sees only 2 classes
)
```

Common hyperparameters (set near the top of each notebook):

- `num_clients`: 5-20 for MNIST demos
- `rounds`: 50-200 (increase for higher fidelity)
- `local_epochs`: 1-5 (more local steps = fewer communications, but may drift)
- `batch_size`: 64-256
- `z_dim`: 64-128 (latent code)
- `lrG`, `lrD`: typically `2e-4` with Adam `betas=(0.5, 0.999)` for DCGAN-style training
Federated specifics:
- Aggregator: FedAvg (baseline), custom rule in the Custom Sync notebook
- Client sampling: all or a subset each round
- Weighting: by client dataset size (recommended)
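To illustrate client sampling and size-weighted averaging, here is a minimal sketch; `sample_clients` and `weighted_fedavg` are illustrative helpers, not part of the package API:

```python
import random

import torch

def sample_clients(num_clients, clients_per_round, seed=None):
    """Pick a random subset of client ids for this round."""
    rng = random.Random(seed)
    return rng.sample(range(num_clients), clients_per_round)

def weighted_fedavg(global_model, local_models, num_samples):
    """Average local weights, weighting each client by its dataset size."""
    total = sum(num_samples)
    weights = [n / total for n in num_samples]
    avg_state = {
        key: sum(w * m.state_dict()[key].float()
                 for w, m in zip(weights, local_models))
        for key in global_model.state_dict()
    }
    global_model.load_state_dict(avg_state)
    return global_model
```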
| Parameter | MNIST | Fashion-MNIST | Notes |
|---|---|---|---|
| `num_clients` | 10-20 | 10-20 | More clients = more realistic federation |
| `communication_rounds` | 50-100 | 100-200 | Increase for better convergence |
| `local_epochs` | 1-3 | 1-3 | Higher values may cause client drift |
| `batch_size` | 64-128 | 64-128 | Adjust based on memory constraints |
| `learning_rate` | 2e-4 | 2e-4 | Classic DCGAN learning rate |
| `nz` (noise dim) | 100 | 100 | Latent space dimension |
- Client Sampling: Select 20-50% of clients per round for efficiency
- Synchronization: Enable both generator and discriminator sync for stability
- Aggregation: Start with FedAvg, experiment with weighted variants
```bash
# Clone and install in development mode
git clone https://github.com/fedgan-team/federated-deep-generative-models.git
cd federated-deep-generative-models
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install
```

Run the test suite with pytest:

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=fedgan --cov-report=html

# Run specific test file
pytest tests/test_models.py -v
```

| Problem | Symptoms | Solution |
|---|---|---|
| Training Instability | Loss oscillations, poor samples | Reduce learning rate, increase local epochs |
| Mode Collapse | Generated samples lack diversity | Lower discriminator LR, add noise to inputs |
| Client Drift | Performance degrades with more local epochs | Reduce local epochs or add regularization |
| Memory Issues | CUDA out of memory | Reduce batch size or model dimensions |
- Use GPU: Training is typically 10-50x faster on CUDA-enabled devices
- Batch Size: Larger batches often lead to more stable training
- Checkpointing: Save models regularly to resume interrupted training
- Monitoring: Use evaluation metrics to catch problems early
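A minimal checkpointing pattern with plain `torch.save` (the file layout and dictionary keys here are illustrative, not the package's own format):

```python
import torch

def save_checkpoint(path, round_idx, generator, discriminator):
    """Persist model and round state so interrupted runs can resume."""
    torch.save({
        "round": round_idx,
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
    }, path)

def load_checkpoint(path, generator, discriminator, device="cpu"):
    """Restore model weights and return the round to resume from."""
    ckpt = torch.load(path, map_location=device)
    generator.load_state_dict(ckpt["generator"])
    discriminator.load_state_dict(ckpt["discriminator"])
    return ckpt["round"]

# e.g. save every 10 rounds:
# if round_idx % 10 == 0:
#     save_checkpoint(f"checkpoints/round_{round_idx}.pt", round_idx, gen, disc)
```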
- McMahan, B., et al. "Communication-efficient learning of deep networks from decentralized data." AISTATS, 2017.
- Li, T., et al. "Federated optimization in heterogeneous networks." MLSys, 2020.
- Goodfellow, I., et al. "Generative adversarial nets." NeurIPS, 2014.
- Radford, A., et al. "Unsupervised representation learning with deep convolutional generative adversarial networks." ICLR, 2016.
- Mirza, M., & Osindero, S. "Conditional generative adversarial nets." arXiv preprint, 2014.
- Salimans, T., et al. "Improved techniques for training GANs." NeurIPS, 2016. (Inception Score)
- Heusel, M., et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium." NeurIPS, 2017. (FID)
We welcome contributions! Areas for contribution include:
- New Datasets: Add support for CIFAR-10, CelebA, etc.
- Advanced Algorithms: Implement FedProx, SCAFFOLD, FedNova
- Privacy: Add differential privacy mechanisms
- Metrics: Implement additional evaluation metrics
- Documentation: Improve docs and tutorials