Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7bcfbd0
Initial commit and file setup
Anwealso Oct 11, 2022
f331ab8
Fixed issue where changes were showing up that I didn't make
Anwealso Oct 11, 2022
3805538
Added basic dataloader
Anwealso Oct 11, 2022
77dfc44
Changed the data import pipeline so it uses a pre-downloaded local fo…
Anwealso Oct 11, 2022
7e9a6b5
Fixed dataset loader
Anwealso Oct 11, 2022
343ba1b
Code cleanup
Anwealso Oct 11, 2022
850e47b
First implementation of VQVAE model architecture
Anwealso Oct 11, 2022
d8ada41
Implemented basic training loop and stats logging wrapper class for m…
Anwealso Oct 11, 2022
9456c9e
Changes to readme
Anwealso Oct 11, 2022
b6bfc88
Moved plotting function to utils.py
Anwealso Oct 11, 2022
4893b5c
Added data preprocessing to the data loader
Anwealso Oct 11, 2022
a846d33
Added data variance to items returned by data loader
Anwealso Oct 11, 2022
72c9f06
Verified training loop, added model saving functionality, fixed impor…
Anwealso Oct 11, 2022
e9ee4d9
Fully successfully implemented model saving and loading pipeline!!!
Anwealso Oct 11, 2022
5e31873
Added exensive debug statements to aid in dataset debugging
Anwealso Oct 11, 2022
3582b22
Verified training script with oasis data and pulled hyperparameters u…
Anwealso Oct 11, 2022
9e7dff3
Added new todos to readme
Anwealso Oct 11, 2022
8d2878b
Added model training metric over time plotting
Anwealso Oct 12, 2022
c0d4762
Added example generation image logging throughout training to show pe…
Anwealso Oct 12, 2022
db6fe74
Trained model up for 20 epochs, moved some plotting functions over to…
Anwealso Oct 14, 2022
a5ca629
Implemented average SSIM tracking through training and plotting over …
Anwealso Oct 14, 2022
446d940
Updated and tweaked layout of readme
Anwealso Oct 14, 2022
2019d37
Fixed up data importing on train.py, updated predict.py to actually run
Anwealso Oct 14, 2022
16c7243
Documentation updates and code cleanup
Anwealso Oct 14, 2022
876e59a
Implemented pixel cnn training loop (OOM error on local cpu so moving…
Anwealso Oct 20, 2022
adeea63
Trained vae for 20 epochs and pixelcnn for 60 epochs and created a ne…
Anwealso Oct 21, 2022
9fc7920
Minor documentation tweaks
Anwealso Oct 21, 2022
8488064
Moved generator training over into train.py from generator_train.py
Anwealso Oct 21, 2022
8e90849
Swapped epoch numbers back from test to production values
Anwealso Oct 21, 2022
739d392
Removed some files that shouldn't be committed
Anwealso Oct 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions recognition/45316207_VQ-VAE/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Ignore the dataset folder
keras_png_slices_data/
vqvae_saved_model/
pixelcnn_saved_model/
codebook_indices.csv
.DS_Store
*/.DS_Store
96 changes: 96 additions & 0 deletions recognition/45316207_VQ-VAE/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Synthetic Brain MRI Image Generation with VQ-VAE (COMP3710)

by Alex Nicholson, 45316207

---

## Project Overview

### The Algorithm and the Problem

The algorithm implemented in this project is a [VQ-VAE](https://arxiv.org/abs/1711.00937) (Vector Quantised - Variational Auto-Encoder) model, which is an architecture that aims to encode data into a compressed format (embedding higher dimensional data into a lower dimenisional subspace) and then decode this compressed format to recreate the original image as closely as possible. For this project, we will be training the model on the OASIS brain MRI image datset so that we can use it to generate novel and realistic synthetic brain MRI images.

### How it Works

It works by transforming the image into a set of encoding vectors, using a CNN (convolutional neural network) encoder network, which are then quantised to fit the codebook vectors of the model. These quantised encodings are then passed to the decoder network which is made up of a transposed convolution (deconvolution) layers, which generated a synthetic reconstruction that is very similar to the original input image. This model is then trained until the VQVAE is very accurate at encoding the images into a condensed format while preserving the information held within.

In addition to being able to reconstruct images, we also might want to generate novel brain images, so to do this we can also train a separate CNN, using the PixelCNN architechure that can generate brain images directly from samples of codebook vectors.

### Goals

The performance goals for this project are, generally, for the model to produce a “reasonably clear image” and also, more concretely, for the model to achieve an average structured similarity index (SSIM) of over 0.6.

---

## Usage Guide

### Installation

1. Install Anaconda
2. Create a clean conda environment and activate it
3. Install all of the required packages (see dependancy list below)
4. Download the OASIS dataset from [this link](https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA/download)

### Usage

* Run `python train.py` to train the model
* Run `python predict.py` to test out the trained model

### Dependancies

The following dependancies were used in the project:

* tensorflow (version 2.9.2)
* tensorflow_probability (version 0.17.0)
* numpy (version 1.23.3)
* matplotlib (version 3.5.1)
* PIL / pillow (version 9.1.0)
* imageio (version 2.22.1)
* skimage (version 0.19.3)

---

## Methods

The training, validation and testing splits of the data were used as provided in the original dataset, with these partitions taking up 85%, 10%, and 5% respectively (total 11,328 images in dataset), in line with good standard practice for dataset partitioning. The data pixel values of the images were normalise to be within -1 to 1 by dividing by 255 and subtracting 1 to avoid data biases.

---

## Results

### Example Generations

Below are some examples of the generations made by the VQ VAE model after 20 epochs of training over the full OASIS training dataset. These generations were produced by putting real MRI image examples from the test set into the model and then getting the reconstructed output from the model.

![alt text](./out/original_vs_reconstructed_0000.png)
![alt text](./out/original_vs_reconstructed_0001.png)
![alt text](./out/original_vs_reconstructed_0002.png)
![alt text](./out/original_vs_reconstructed_0003.png)
![alt text](./out/original_vs_reconstructed_0004.png)
![alt text](./out/original_vs_reconstructed_0005.png)

### Generation Quality Over Time

Below is an animation of the progression of the quality of the model's generations over the course of training.
![alt text](./out/vqvae_training_progression.gif)

### Training Metrics

The various loss metrics of the model were recorded throughout training to track its performance over time, these include:

* Total Loss: What does the total loss represent???
* Reconstruction Loss: What does the reconstruction loss represent???
* VQ VAE Loss: What does the VQ VAE loss represent???

These losses are plotted over the course of the models training in both standard and log scales below:
![alt text](./out/training_loss_curves.png)

Model Log Loss Progress Throughout Training:
![alt text](./out/training_logloss_curves.png)

In addition to statistical losses, a more real world metric to track the quality of our generations over time is to compare the similarity of the reconstructed output images it produces with the original input image they are created from. This similarity can be measured by the SSIM (Structured Similarity Index). At the end of each epoch, the SSIM was computed for 10 randomly selected images from the test dataset, and the average was recorded. This average SSIM can be seen plotted over time below:
![alt text](./training_ssim_curve.png)

---

Made with ❤️
108 changes: 108 additions & 0 deletions recognition/45316207_VQ-VAE/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
dataset.py

Alex Nicholson (45316207)
11/10/2022

Contains the data loader for loading and preprocessing the OASIS data

"""


import glob
import numpy as np
import PIL
import os

import matplotlib.pyplot as plt


def load_dataset(max_images=None, verbose=False):
"""
Loads the OASIS dataset of brain MRI images

Parameters:
(optional) max_images (int): The maximum number of images of the dataset to be used (default=None)
(optional) verbose (bool): Whether a description of the dataset should be printed after it has loaded

Returns:
train_data_scaled (ndarray): Numpy array of scaled image data for training (9,664 images max)
test_dat_scaleda (ndarray): Numpy array of scaled image data testing (1,120 images max)
validate_data_scaled (ndarray): Numpy array of scaled image data validation (544 images max)
data_variance (int): Variance of the test dataset
"""

print("Loading dataset...")

# File paths
images_path = "keras_png_slices_data/"
test_path = images_path + "keras_png_slices_test/"
train_path = images_path + "keras_png_slices_train/"
validate_path = images_path + "keras_png_slices_validate/"
dataset_paths = [test_path, train_path, validate_path]

# Set up the lists we will load our data into
test_data = []
train_data = []
validate_data = []
datasets = [test_data, train_data, validate_data]

# Load all the images into numpy arrays
for i in range(0, len(dataset_paths)):
# Get all the png files in this dataset_path directory
images_list = glob.glob(os.path.join(dataset_paths[i], "*.png"))

images_collected = 0
for img_filename in images_list:
# Stop loading in images if we hit out max image limit
if max_images and images_collected >= max_images:
break

# Open the image
img = PIL.Image.open(img_filename)
# Convert image to numpy array
data = np.asarray(img)
datasets[i].append(data)

# Close the image (not strictly necessary)
del img
images_collected = images_collected + 1

# Convert the datasets into numpy arrays
train_data = np.array(train_data)
test_data = np.array(test_data)
validate_data = np.array(validate_data)

# Preprocess the data
train_data = np.expand_dims(train_data, -1)
test_data = np.expand_dims(test_data, -1)
validate_data = np.expand_dims(validate_data, -1)
# Scale the data into values between -0.5 and 0.5 (range of 1 centred about 0)
train_data_scaled = (train_data / 255.0) - 0.5
test_data_scaled = (test_data / 255.0) - 0.5
validate_data_scaled = (validate_data / 255.0) - 0.5

# Get the dataset variance
data_variance = np.var(train_data / 255.0)

if verbose == True:
# Debug dataset loading
print(f"###train_data ({type(train_data)}): {np.shape(train_data)}###")
print(f"###test_data ({type(test_data)}): {np.shape(test_data)}###")
print(f"###train_data_scaled ({type(train_data_scaled)}): {np.shape(train_data_scaled)}###")
print(f"###test_data_scaled ({type(test_data_scaled)}): {np.shape(test_data_scaled)}###")
print(f"###data_variance ({type(data_variance)}): {data_variance}###")
print('')

print(f"###validate_data ({type(validate_data)}): {np.shape(validate_data)}###")
print(f"###validate_data_scaled ({type(validate_data_scaled)}): {np.shape(validate_data_scaled)}###")

print('')
print('')

return (train_data_scaled, validate_data_scaled, test_data_scaled, data_variance)


if __name__ == "__main__":
# Run a test
load_dataset(max_images=1000)
Loading