This repository contains a PyTorch implementation of the paper *The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks* by Jonathan Frankle and Michael Carbin. It is written so that it can be easily adapted to other models and datasets.

Install the requirements:

```bash
pip3 install -r requirements.txt
```

Run the main script, for example:

```bash
python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
```
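
In the command above, `--prune_percent` and `--prune_iterations` together determine how sparse the network becomes. The sketch below assumes that every pruning round removes `prune_percent`% of the weights that are still unpruned, so the kept fraction compounds as (1 - p/100)^k after k rounds, as in the paper's iterative pruning. Exactly how many of the 35 cycles prune (for example, whether the first cycle trains the dense network) is not stated here, so the function takes the number of rounds explicitly. `remaining_fraction` is a hypothetical helper, not part of the repository.

```python
def remaining_fraction(prune_percent: float, rounds: int) -> float:
    """Fraction of weights still unpruned after `rounds` pruning rounds.

    Assumes each round removes `prune_percent`% of the weights that survived
    the previous rounds, so the kept fraction compounds geometrically.
    """
    if not 0.0 <= prune_percent < 100.0:
        raise ValueError("prune_percent must be in [0, 100)")
    if rounds < 0:
        raise ValueError("rounds must be non-negative")
    keep_per_round = 1.0 - prune_percent / 100.0
    return keep_per_round ** rounds


if __name__ == "__main__":
    # Percentage of weights left at a few points of a 10%-per-round schedule.
    for k in (0, 1, 5, 10, 20, 34, 35):
        print(f"after {k:2d} rounds: {100.0 * remaining_fraction(10, k):6.2f}% remain")
```

Under this assumption, ten rounds at 10% leave 0.9^10, or about 34.9%, of the weights.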

The following arguments are available:

- `--prune_type`: Type of pruning.
  - Options: `lt` (Lottery Ticket Hypothesis), `reinit` (random reinitialization)
  - Default: `lt`
- `--arch_type`: Type of architecture.
  - Options: `fc1` (simple fully connected network), `lenet5` (LeNet5), `AlexNet` (AlexNet), `resnet18` (ResNet18), `vgg16` (VGG16)
  - Default: `fc1`
- `--dataset`: Choice of dataset.
  - Options: `mnist`, `fashionmnist`, `cifar10`, `cifar100`
  - Default: `mnist`
- `--prune_percent`: Percentage of weights to be pruned after each cycle.
  - Default: `10`
- `--prune_iterations`: Number of pruning cycles to perform.
  - Default: `35`
- `--lr`: Learning rate.
  - Default: `1.2e-3`
- `--batch_size`: Batch size.
  - Default: `60`
- `--end_iter`: Number of epochs.
  - Default: `100`
- `--print_freq`: Frequency for printing accuracy and loss.
  - Default: `1`
- `--valid_freq`: Frequency of validation.
  - Default: `1`
- `--gpu`: Which GPU the program should use.
  - Default: `0`
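
The two `--prune_type` options differ in what happens to the surviving weights after each pruning cycle. The dependency-free sketch below illustrates one cycle on a flat list of weights; it is not the code in `main.py`, and `prune_round` and `reset_weights` are hypothetical names. It assumes magnitude pruning that removes `prune_percent`% of the still-active weights, then either rewinds the survivors to their original initialization (`lt`, the lottery-ticket procedure from the paper) or gives them fresh random values (`reinit`; the uniform [-1, 1] range is an arbitrary assumption).

```python
import random
from typing import List, Sequence


def prune_round(trained: Sequence[float], mask: Sequence[int], prune_percent: float) -> List[int]:
    """Return a new 0/1 mask in which the `prune_percent`% of still-active weights
    with the smallest absolute trained value are additionally pruned."""
    alive = [i for i, m in enumerate(mask) if m]
    n_prune = int(len(alive) * prune_percent / 100)
    # sorted() is stable, so ties in |w| are broken by position.
    to_prune = set(sorted(alive, key=lambda i: abs(trained[i]))[:n_prune])
    return [0 if (m == 0 or i in to_prune) else 1 for i, m in enumerate(mask)]


def reset_weights(mask: Sequence[int], init: Sequence[float], prune_type: str,
                  rng: random.Random) -> List[float]:
    """Weights for the next training cycle; pruned positions are set to 0.0."""
    if prune_type == "lt":
        # Lottery ticket: survivors are rewound to their original initial values.
        return [w0 if m else 0.0 for w0, m in zip(init, mask)]
    if prune_type == "reinit":
        # Baseline: survivors get fresh random values (uniform in [-1, 1] is an assumption).
        return [rng.uniform(-1.0, 1.0) if m else 0.0 for m in mask]
    raise ValueError(f"unknown prune_type: {prune_type!r}")


if __name__ == "__main__":
    rng = random.Random(0)
    init = [rng.uniform(-1.0, 1.0) for _ in range(10)]
    trained = [w + rng.gauss(0.0, 0.5) for w in init]  # stand-in for a training cycle
    mask = prune_round(trained, [1] * len(init), prune_percent=20)
    print("mask:", mask)
    print("lt weights:", reset_weights(mask, init, "lt", rng))
```

In a real network the weights are tensors and the mask is kept per parameter, but the rewind-versus-reinitialize distinction is the same.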
 
Adding a new architecture:

For example, suppose you want to add an architecture named `new_model` that is compatible with the `mnist` dataset.

- Go to the `/archs/mnist/` directory and create a file `new_model.py`.
- Paste your PyTorch-compatible model into `new_model.py`.
- IMPORTANT: Make sure the input size, number of classes, number of channels, and batch size in your `new_model.py` match the dataset you are adding it for (in this case, `mnist`).
- Open `main.py`, go to line 36, and look for the comment `# Data Loader`. Find your dataset (in this case, `mnist`) and add `new_model` to the end of the line `from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet`.
- Go to line 82 and add the following:

  ```python
  elif args.arch_type == "new_model":
      model = new_model.new_model_name().to(device)
  ```

  Here, `new_model_name()` is the name you gave the model inside `new_model.py`.
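
The IMPORTANT step above is easy to get wrong for fully connected models, whose first layer must match the flattened image size. The sketch below records the standard image shape and class count of each dataset listed under `--dataset` (properties of the datasets themselves, assuming the data loader does no resizing) and checks a model's first and last layer sizes against them. `DatasetSpec`, `DATASET_SPECS`, `fc_input_size`, and `check_model_dims` are hypothetical helpers, not part of the repository.

```python
from typing import Dict, NamedTuple


class DatasetSpec(NamedTuple):
    channels: int
    height: int
    width: int
    num_classes: int


# Standard image shapes and class counts (no resizing assumed).
DATASET_SPECS: Dict[str, DatasetSpec] = {
    "mnist": DatasetSpec(channels=1, height=28, width=28, num_classes=10),
    "fashionmnist": DatasetSpec(channels=1, height=28, width=28, num_classes=10),
    "cifar10": DatasetSpec(channels=3, height=32, width=32, num_classes=10),
    "cifar100": DatasetSpec(channels=3, height=32, width=32, num_classes=100),
}


def fc_input_size(dataset: str) -> int:
    """Number of input features after flattening one image of `dataset`."""
    spec = DATASET_SPECS[dataset]
    return spec.channels * spec.height * spec.width


def check_model_dims(dataset: str, in_features: int, num_classes: int) -> None:
    """Raise ValueError if a fully connected model's first/last layer sizes do not fit `dataset`."""
    expected_in = fc_input_size(dataset)
    expected_out = DATASET_SPECS[dataset].num_classes
    if in_features != expected_in:
        raise ValueError(f"{dataset}: first layer expects {expected_in} inputs, got {in_features}")
    if num_classes != expected_out:
        raise ValueError(f"{dataset}: last layer should have {expected_out} outputs, got {num_classes}")


if __name__ == "__main__":
    for name, spec in DATASET_SPECS.items():
        print(name, fc_input_size(name), spec.num_classes)
    check_model_dims("mnist", in_features=784, num_classes=10)  # passes silently
```

For MNIST this gives 1 × 28 × 28 = 784 input features and 10 classes, so the first `nn.Linear` of a fully connected `new_model.py` would take 784 inputs and the last would produce 10 outputs.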

Adding a new dataset:

For example, suppose you want to add a dataset named `new_dataset` that is compatible with the `fc1` architecture.

- Go to `/archs` and create a directory named `new_dataset`.
- Go to `/archs/new_dataset/` and add a file named `fc1.py`, or copy it from an existing dataset folder.
- IMPORTANT: Make sure the input size, number of classes, number of channels, and batch size in your `fc1.py` match the dataset you are adding (in this case, `new_dataset`).
- Open `main.py`, go to line 58, and add the following:

  ```python
  elif args.dataset == "new_dataset":
      traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
      testdataset = datasets.new_dataset('../data', train=False, transform=transform)
      from archs.new_dataset import fc1
  ```

  Here, `datasets.new_dataset` stands for the dataset class that PyTorch provides for your dataset. Note that, as of now, you can only add datasets that are natively available in PyTorch.
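
The first two steps above can also be scripted. The sketch below uses only `pathlib` and `shutil`; `scaffold_dataset` is a hypothetical helper, and it assumes `repo_root` is the repository root containing `archs/`. It creates `archs/<new_dataset>/` and copies `fc1.py` from an existing dataset folder (`mnist` by default), leaving files that already exist untouched. The copied file still has to be edited so that its input size and number of classes match the new dataset.

```python
import shutil
from pathlib import Path
from typing import Sequence


def scaffold_dataset(repo_root: Path, new_dataset: str, template_dataset: str = "mnist",
                     arch_files: Sequence[str] = ("fc1.py",)) -> Path:
    """Create archs/<new_dataset>/ and copy `arch_files` into it from
    archs/<template_dataset>/. Files that already exist are not overwritten."""
    src_dir = repo_root / "archs" / template_dataset
    dst_dir = repo_root / "archs" / new_dataset
    # Check every template file first so nothing is created if one is missing.
    missing = [name for name in arch_files if not (src_dir / name).is_file()]
    if missing:
        raise FileNotFoundError(f"missing template files in {src_dir}: {missing}")
    dst_dir.mkdir(parents=True, exist_ok=True)
    for name in arch_files:
        dst = dst_dir / name
        if not dst.exists():  # keep any copy the user has already edited
            shutil.copy2(src_dir / name, dst)
    return dst_dir
```

For example, calling `scaffold_dataset(Path("."), "new_dataset")` from the repository root would create `archs/new_dataset/fc1.py` as a copy of `archs/mnist/fc1.py`.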

Combining plots:

- Open `combine_plots.py` and add or remove the datasets/architectures whose combined plots you want to generate (this assumes you have already run `main.py` for those datasets/architectures and produced the weights).
- Run `python3 combine_plots.py`.
- Go to `/plots/lt/combined_plots/` to see the graphs.

Kindly raise an issue if you have any problem with the instructions.

Dataset and architecture combinations that have been tested:

| Dataset | fc1 | LeNet5 | AlexNet | VGG16 | Resnet18 |
|---|---|---|---|---|---|
| MNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | 
| CIFAR10 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | 
| FashionMNIST | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | 
| CIFAR100 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | 

Repository structure:

```
Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── fc1.py
│       ├── LeNet5.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py
```

Parts of the code were borrowed from ktkth5.

Open a new issue or a pull request in case you are facing any difficulty with the code base or want to contribute to it.
