diff --git a/.gitignore b/.gitignore
index 6e10f0b1b8..fc1e39bc51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,4 +132,4 @@ dmypy.json
 .idea/
 
 # no tracking mypy config file
-mypy.ini
+mypy.ini
\ No newline at end of file
diff --git a/recognition/44801582_OASIS_VAE/README.md b/recognition/44801582_OASIS_VAE/README.md
new file mode 100644
index 0000000000..08d933f405
--- /dev/null
+++ b/recognition/44801582_OASIS_VAE/README.md
@@ -0,0 +1,70 @@
+# COMP3710 PatternFlow Report
+## Alon Nusem - s4480158
+### Project: VQVAE on the Oasis brain dataset
+
+## Project Overview
+### The VQVAE Model
+A variational autoencoder is an autoencoder with modified training that ensures the latent space it learns has better properties than that of a regular autoencoder. This helps avoid overfitting and is achieved by returning a distribution over the latent space and adding a regularisation term to the loss function [1].
+
+A vector quantised VAE (VQVAE) is a form of VAE that uses vector quantisation to obtain a discrete latent representation.
+
+In essence, this model is a more effective way to compress data into a latent space and then restore it with significant accuracy, aiming for a structural similarity (SSIM) above 0.6.
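+
+For reference, the training objective follows the standard VQVAE formulation, which is how the loss terms are split between modules.py (codebook and commitment terms) and train.py (variance-scaled reconstruction term); sg denotes stop-gradient and beta is the commitment weight (0.25 in this implementation):
+
+$$\mathcal{L} = \frac{\lVert x - \hat{x}\rVert^2}{\sigma^2_{\text{train}}} + \lVert \text{sg}[z_e(x)] - z_q \rVert^2 + \beta\,\lVert z_e(x) - \text{sg}[z_q] \rVert^2$$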
+
+### The Dataset
+The dataset being used for this analysis is the OASIS brain dataset, captured during the OASIS brain study. This is an expansive dataset separated into 544 test images, 1,120 validation images, and 9,664 training images of cross-sectional brain MRI slices.
+
+### The Goal
+Using a vector quantised variational autoencoder, the dataset can be analysed and reduced into a denser latent space. Training this VQVAE allows the model to essentially compress and then decompress inputs accurately. Following this, a generative network can be designed; by feeding its output into the decoder from the VQVAE, new images can be created from this latent space.
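+
+To make that pipeline concrete, the sketch below shows the decode half of the process: codebook indices are mapped back to embedding vectors and pushed through the VQVAE decoder. The random codes grid is only a stand-in for codes sampled autoregressively from the PixelCNN prior (predict.py contains the full sampling loop), and the shapes and weight path assume the defaults used in this repository.
+
+```python
+import numpy as np
+import tensorflow as tf
+import modules
+
+# Assumes trained weights are available at the path referenced by predict.py.
+vqvae = modules.VQVAE(16, 128)
+vqvae.load_weights("samples/vqvae_model_weights.h5")
+
+# Stand-in for a 64x64 grid of codebook indices sampled from the PixelCNN prior.
+codes = np.random.randint(0, 128, size=(1, 64, 64))
+
+# Look up the embedding vector for each index, then decode the resulting feature map.
+embeddings = vqvae.get_layer("vector_quantizer").embeddings    # shape (16, 128)
+one_hot = tf.one_hot(codes, 128)                               # shape (1, 64, 64, 128)
+flat = tf.reshape(one_hot, (-1, 128))
+quantized = tf.reshape(tf.matmul(flat, embeddings, transpose_b=True), (1, 64, 64, 16))
+generated = vqvae.get_layer("decoder").predict(quantized)      # shape (1, 256, 256, 1)
+```
+
+The decoder output is in the same shifted range as the training data, so 0.5 is added back before displaying it (as predict.py does).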
+
+## Results
+After training the VQVAE on a subset of the training dataset, the model was evaluated on an unseen section of the test dataset. Below is a sample of 8 brains after being reconstructed through encoding, vector quantization, and decoding after 10 epochs.
+
+![](samples/reconstruction.png)
+
+During this run, an SSIM of 0.91 was achieved on a sample of 500 test dataset images. This can be re-evaluated in the predict script. While running the training, the following loss plots were produced:
+
+VQVAE:\
+![](samples/training_loss_curves_vq_vae.png)
+
+PixelCNN:\
+![](samples/training_loss_curves_pixelcnn.png)
+
+Both of these plateau, which suggests that additional epochs of training would not change much. Adding more data may help, but I touch on this in the final section.
+
+However, while these models seem to train well and the VQVAE does function, there must be some issue with either the PixelCNN or the generation code, as new brains cannot be produced well:
+
+![](samples/Figure_2.png)
+
+![](samples/Figure_3.png)
+
+I'm not sure where this implementation went wrong and it requires further analysis, but it does illustrate how a low-dimensionality code can be transformed into an arguably more brain-like reproduction.
+
+
+## How to set up this project
+### Dependencies
+- Python 3.9
+- tensorflow=2.10.0
+- tensorflow-probability=0.18.0
+- scikit-image=0.18.1
+- matplotlib-base=3.3.4
+- numpy-base=1.21.5
+- pillow=9.0.1
+
+### Steps for reproducing
+1. Set up a new conda environment with the dependencies listed above.
+2. Download the dataset from https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA (this is a preprocessed set from the OASIS dataset; it also includes segmentation masks, but those aren't necessary here).
+3. Extract the dataset into a folder labelled data in the 44801582_OASIS_VAE directory.
+4. Use train.py to train the VQVAE or PixelCNN.
+5. Use predict.py to generate new brains (the default run uses the provided samples; adjustment is needed if you want to use your own model results). A quick reconstruction check is sketched below.
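+
+As a sanity check of the setup, the following minimal sketch (based on predict.py, and assuming the trained weights it references are available at samples/vqvae_model_weights.h5) reconstructs a few test images and prints an SSIM score:
+
+```python
+import dataset
+import modules
+from skimage.metrics import structural_similarity
+
+# Load a handful of images per split; the limit keeps memory usage manageable.
+_, test_data, _, _ = dataset.oasis_dataset(10)
+
+vqvae = modules.VQVAE(16, 128)
+vqvae.load_weights("samples/vqvae_model_weights.h5")
+
+# Reconstruct the test images and compare one of them against its original.
+reconstructions = vqvae.predict(test_data)
+original = test_data[0, :, :, 0]
+reconstructed = reconstructions[0, :, :, 0]
+print(structural_similarity(original, reconstructed,
+                            data_range=original.max() - original.min()))
+```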
+
+## How to improve on these results (and issues)
+The first issue is that the current dataset implementation is not great. It loads everything into a NumPy array, which is very memory intensive and makes training crash unless the input data size is limited. This needs updating, as it would improve the training process dramatically; one possible streaming approach is sketched below.
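+
+One possible direction (a rough sketch only, not wired into train.py, and assuming the same directory layout as dataset.py) is to stream the PNG slices with tf.data rather than materialising each split as a NumPy array:
+
+```python
+import tensorflow as tf
+
+def oasis_tf_dataset(path, batch_size=32):
+    # Stream PNG slices from disk instead of holding every image in memory.
+    files = tf.data.Dataset.list_files(f"{path}/*.png", shuffle=True)
+
+    def _load(file_path):
+        img = tf.io.decode_png(tf.io.read_file(file_path), channels=1)
+        img = tf.image.convert_image_dtype(img, tf.float32)  # scales to [0, 1]
+        return img - 0.5  # match the [-0.5, 0.5] range used in dataset.py
+
+    return (files.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
+            .batch(batch_size)
+            .prefetch(tf.data.AUTOTUNE))
+```
+
+The training variance that Trainer in train.py divides by would then need to be computed in a separate pass (or estimated from a subset).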
+
+The PixelCNN and VQVAE could both be improved; model-wise they are not very complex implementations of their base forms, so better performance is possible.
+
+## Resources
+[1] https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
+
+[2] https://github.com/ritheshkumar95/pytorch-vqvae
+
+[3] https://keras.io/examples/generative/vq_vae/
\ No newline at end of file
diff --git a/recognition/44801582_OASIS_VAE/dataset.py b/recognition/44801582_OASIS_VAE/dataset.py
new file mode 100644
index 0000000000..c3952b35fe
--- /dev/null
+++ b/recognition/44801582_OASIS_VAE/dataset.py
@@ -0,0 +1,35 @@
+import os
+import numpy as np
+from PIL import Image
+
+
+def load_data(path, img_limit=False, want_var=False):
+    # Loads PNG slices from path into a float32 array scaled to [-0.5, 0.5].
+    dataset = []
+
+    for i, img in enumerate(os.listdir(path)):
+        if img_limit and i >= img_limit:
+            # Stop once the optional image limit has been reached.
+            break
+        else:
+            img = Image.open(f"{path}/{img}")
+            data = np.asarray(img, dtype=np.float32)
+            dataset.append(data)
+
+    dataset = np.array(dataset, dtype=np.float32)
+
+    if want_var:
+        data_variance = np.var(dataset / 255.0)
+    else:
+        data_variance = None
+
+    dataset = np.expand_dims(dataset, -1)
+    dataset = (dataset / 255.0) - 0.5
+
+    return dataset, data_variance
+
+
+def oasis_dataset(images=False):
+    # Returns (train, test, validate, variance); `images` optionally caps how many files are read per split.
+    train, variance = load_data("data/keras_png_slices_data/keras_png_slices_train", images, True)
+    test, _ = load_data("data/keras_png_slices_data/keras_png_slices_test", images)
+    validate, _ = load_data("data/keras_png_slices_data/keras_png_slices_validate", images)
+
+    return train, test, validate, variance
diff --git a/recognition/44801582_OASIS_VAE/modules.py b/recognition/44801582_OASIS_VAE/modules.py
new file mode 100644
index 0000000000..b6b3079d21
--- /dev/null
+++ b/recognition/44801582_OASIS_VAE/modules.py
@@ -0,0 +1,115 @@
+import tensorflow as tf
+from tensorflow import keras
+import numpy as np
+
+
+class VectorQuantizer(keras.layers.Layer):
+    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
+        super().__init__(**kwargs)
+        self.embedding_dim = embedding_dim
+        self.num_embeddings = num_embeddings
+
+        # Weight of the commitment loss term.
+        self.beta = beta
+
+        w_init = tf.random_uniform_initializer()
+        self.embeddings = tf.Variable(
+            initial_value=w_init(shape=(self.embedding_dim, self.num_embeddings), dtype="float32"),
+            trainable=True,
+            name="embeddings_vqvae")
+
+    def call(self, x):
+        encoding_indices = self.get_code_indices(tf.reshape(x, [-1, self.embedding_dim]))
+        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
+        quantized = tf.reshape(tf.matmul(encodings, self.embeddings, transpose_b=True), tf.shape(x))
+
+        # Commitment loss (pulls encoder outputs towards the codebook) plus codebook loss.
+        self.add_loss(self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
+                      + tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2))
+
+        # Straight-through estimator: gradients bypass the quantisation step.
+        quantized = x + tf.stop_gradient(quantized - x)
+
+        return quantized
+
+    def get_code_indices(self, flattened_inputs):
+        # Index of the nearest embedding (squared L2 distance) for each flattened input vector.
+        similarity = tf.matmul(flattened_inputs, self.embeddings)
+        encoding_indices = tf.argmin((tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
+                                      + tf.reduce_sum(self.embeddings ** 2, axis=0) - 2 * similarity), axis=1)
+
+        return encoding_indices
+
+
+def def_encoder(latent_dim):
+    encoder_inputs = keras.Input(shape=(256, 256, 1))
+    x = keras.layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(
+        encoder_inputs
+    )
+    x = keras.layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
+    encoder_outputs = keras.layers.Conv2D(latent_dim, 1, padding="same")(x)
+    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")
+
+
+def def_decoder(latent_dim):
+    latent_inputs = keras.Input(shape=def_encoder(latent_dim).output.shape[1:])
+    x = keras.layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
+        latent_inputs
+    )
+    x = keras.layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
+    decoder_outputs = keras.layers.Conv2DTranspose(1, 3, padding="same")(x)
+    return keras.Model(latent_inputs, decoder_outputs, name="decoder")
+
+
+def VQVAE(latent_dim=16, num_embeddings=64):
+    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
+    encoder = def_encoder(latent_dim)
+    decoder = def_decoder(latent_dim)
+    inputs = keras.Input(shape=(256, 256, 1))
+    encoded = encoder(inputs)
+    quantized = vq_layer(encoded)
+    reconstructions = decoder(quantized)
+    return keras.Model(inputs, reconstructions, name="vq_vae")
+
+
+class MaskedConvLayer(keras.layers.Layer):
+    def __init__(self, **kwargs):
+        super(MaskedConvLayer, self).__init__()
+        self.convolution = keras.layers.Conv2D(**kwargs)
+
+    def build(self, input_shape):
+        self.convolution.build(input_shape)
+        kernel_shape = self.convolution.kernel.get_shape()
+        # Zero out the centre pixel and everything after it, so each position
+        # only sees the pixels generated before it.
+        self.mask = np.zeros(shape=kernel_shape)
+        self.mask[: kernel_shape[0] // 2, ...] = 1.0
+        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
+
+    def call(self, inputs):
+        self.convolution.kernel.assign(self.convolution.kernel * self.mask)
+        return self.convolution(inputs)
+
+
+class ResidBlock(keras.layers.Layer):
+    def __init__(self, filters, **kwargs):
+        super(ResidBlock, self).__init__(**kwargs)
+        self.c1 = keras.layers.Conv2D(filters, 1, activation="relu")
+        self.mc = MaskedConvLayer(filters=filters // 2, kernel_size=3, activation="relu", padding="same")
+        self.c2 = keras.layers.Conv2D(filters, 1, activation="relu")
+
+    def call(self, inputs):
+        x = self.c1(inputs)
+        x = self.mc(x)
+        x = self.c2(x)
+        return keras.layers.add([inputs, x])
+
+
+def PixelCNN(latent_dim, num_embeddings, num_residual_blocks, num_pixelcnn_layers):
+    inputs = keras.Input(def_encoder(latent_dim).layers[-1].output_shape, dtype=tf.int32)
+    encoding = tf.one_hot(inputs, num_embeddings)
+    x = MaskedConvLayer(filters=128, kernel_size=7, activation="relu", padding="same")(encoding)
+
+    for _ in range(num_residual_blocks):
+        x = ResidBlock(128)(x)
+    for _ in range(num_pixelcnn_layers):
+        x = MaskedConvLayer(filters=128, kernel_size=1, activation="relu", padding="valid")(x)
+
+    output = keras.layers.Conv2D(num_embeddings, 1, 1, padding="valid")(x)
+    pixel_cnn = keras.Model(inputs, output)
+
+    return pixel_cnn
diff --git a/recognition/44801582_OASIS_VAE/predict.py b/recognition/44801582_OASIS_VAE/predict.py
new file mode 100644
index 0000000000..fea3d34c0f
--- /dev/null
+++ b/recognition/44801582_OASIS_VAE/predict.py
@@ -0,0 +1,107 @@
+from tensorflow import keras
+import tensorflow as tf
+import tensorflow_probability as tfp
+import numpy as np
+import dataset
+import modules
+import matplotlib.pyplot as plt
+from skimage.metrics import structural_similarity
+
+
+def create_new_brains():
+    (train_data, test_data, validate_data, data_variance) = dataset.oasis_dataset(images=10)
+
+    vqvae = modules.VQVAE(16, 128)
+    vqvae.load_weights("samples/vqvae_model_weights.h5")
+
+    pixelcnn = modules.PixelCNN(16, 128, 2, 2)
+    pixelcnn.load_weights("samples/pixelcnn_model_weights.h5")
+
+    # Wrap the PixelCNN in a sampler that draws codebook indices from its categorical output.
+    inputs = keras.layers.Input(shape=(64, 64, 16))
+    outputs = pixelcnn(inputs, training=False)
+    categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
+    outputs = categorical_layer(outputs)
+    sampler = keras.Model(inputs, outputs)
+
+    batch = 10
+    rows = 64
+    cols = 64
+    priors = np.zeros(shape=(batch, rows, cols))
+
+    # Autoregressive sampling: each code is conditioned on the codes generated before it.
+    for row in range(rows):
+        for col in range(cols):
+            priors[:, row, col] = sampler.predict(priors)[:, row, col]
+            print(f"{row * cols + col + 1}/{rows * cols}")
+
+    # Map the sampled indices back to embedding vectors and decode them into images.
+    pretrained_embeddings = vqvae.get_layer("vector_quantizer").embeddings
+    one_hot = tf.one_hot(priors.astype("int32"), 128).numpy()
+    quantized = tf.reshape(tf.matmul(one_hot.astype("float32"),
+                                     pretrained_embeddings, transpose_b=True), (-1, *(64, 64, 16)))
+
+    generated_samples = vqvae.get_layer("decoder").predict(quantized)
+
+    for i in range(batch):
+        plt.subplot(1, 2, 1)
+        plt.imshow(priors[i], cmap='gray')
+        plt.title("Code")
+        plt.axis("off")
+
+        plt.subplot(1, 2, 2)
+        plt.imshow(generated_samples[i].squeeze() + 0.5, cmap='gray')
+        plt.title("Generated Sample")
+        plt.axis("off")
+        plt.show()
+
+
+def get_structural_similarity():
+    vqvae = modules.VQVAE(16, 128)
+    vqvae.load_weights("samples/vqvae_model_weights.h5")
+    _, test_data, _, _ = dataset.oasis_dataset(500)
+
+    similarity_scores = []
+    reconstructions_test = vqvae.predict(test_data)
+
+    for i in range(reconstructions_test.shape[0]):
+        original = test_data[i, :, :, 0]
+        reconstructed = reconstructions_test[i, :, :, 0]
+
+        similarity_scores.append(structural_similarity(original, reconstructed,
+                                                       data_range=original.max() - original.min()))
+
+    average_similarity = np.average(similarity_scores)
+
+    print(average_similarity)
+
+
+def plot_reconstructions():
+    vqvae = modules.VQVAE(16, 128)
+    vqvae.load_weights("samples/vqvae_model_weights.h5")
+    _, test_data, _, _ = dataset.oasis_dataset(500)
+
+    num_tests = 4
+    test_images = test_data[np.random.choice(len(test_data), num_tests)]
+    reconstructions = vqvae.predict(test_images)
+
+    i = 0
+    plt.figure(figsize=(num_tests * 2, 4), dpi=512)
+    for test_image, reconstructed_image in zip(test_images, reconstructions):
+        test_image = test_image.squeeze()
+        reconstructed_image = reconstructed_image[:, :, 0]
+        plt.subplot(num_tests, 2, 2 * i + 1)
+        plt.imshow(test_image, cmap='gray')
+        plt.title("Original")
+        plt.axis("off")
+
+        plt.subplot(num_tests, 2, 2 * i + 2)
+        plt.imshow(reconstructed_image, cmap='gray')
+        plt.title(f"Reconstructed (SSIM:{structural_similarity(test_image, reconstructed_image, data_range=test_image.max() - test_image.min()):.2f})")
+
+        plt.axis("off")
+
+        i += 1
+
+    plt.show()
+
+
+if __name__ == "__main__":
+    plot_reconstructions()
\ No newline at end of file
diff --git a/recognition/44801582_OASIS_VAE/samples/Figure_1.png b/recognition/44801582_OASIS_VAE/samples/Figure_1.png
new file mode 100644
index 0000000000..4079b6c251
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/Figure_1.png differ
diff --git a/recognition/44801582_OASIS_VAE/samples/Figure_2.png b/recognition/44801582_OASIS_VAE/samples/Figure_2.png
new file mode 100644
index 0000000000..20572bed7b
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/Figure_2.png differ
diff --git a/recognition/44801582_OASIS_VAE/samples/Figure_3.png b/recognition/44801582_OASIS_VAE/samples/Figure_3.png
new file mode 100644
index 0000000000..09ef49ba75
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/Figure_3.png differ
diff --git a/recognition/44801582_OASIS_VAE/samples/reconstruction.png b/recognition/44801582_OASIS_VAE/samples/reconstruction.png
new file mode 100644
index 0000000000..4f6c2ec24d
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/reconstruction.png differ
diff --git a/recognition/44801582_OASIS_VAE/samples/training_loss_curves_pixelcnn.png b/recognition/44801582_OASIS_VAE/samples/training_loss_curves_pixelcnn.png
new file mode 100644
index 0000000000..ff776f7c02
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/training_loss_curves_pixelcnn.png differ
diff --git a/recognition/44801582_OASIS_VAE/samples/training_loss_curves_vq_vae.png b/recognition/44801582_OASIS_VAE/samples/training_loss_curves_vq_vae.png
new file mode 100644
index 0000000000..b465c2f14a
Binary files /dev/null and b/recognition/44801582_OASIS_VAE/samples/training_loss_curves_vq_vae.png differ
diff --git a/recognition/44801582_OASIS_VAE/train.py b/recognition/44801582_OASIS_VAE/train.py
new file mode 100644
index 0000000000..70507d7045
--- /dev/null
+++ b/recognition/44801582_OASIS_VAE/train.py
@@ -0,0 +1,120 @@
+import os
+import dataset
+import modules
+from datetime import datetime
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import numpy as np
+from skimage.metrics import structural_similarity
+
+
+class Trainer(tf.keras.models.Model):
+    def __init__(self, train_variance, latent_dim, num_embeddings, **kwargs):
+        super(Trainer, self).__init__(**kwargs)
+        self.train_variance = train_variance
+        self.latent_dim = latent_dim
+        self.num_embeddings = num_embeddings
+
+        self.vqvae = modules.VQVAE(self.latent_dim, self.num_embeddings)
+
+        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
+        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
+        self.vq_loss_tracker = tf.keras.metrics.Mean(name="vq_loss")
+
+    @property
+    def metrics(self):
+        return [
+            self.total_loss_tracker,
+            self.reconstruction_loss_tracker,
+            self.vq_loss_tracker,
+        ]
+
+    def train_step(self, x):
+        with tf.GradientTape() as tape:
+            reconstructions = self.vqvae(x)
+
+            # Reconstruction loss is scaled by the training-data variance; the VQ losses
+            # (codebook and commitment terms) are collected from the model's add_loss calls.
+            reconstruction_loss = (tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance)
+            total_loss = reconstruction_loss + sum(self.vqvae.losses)
+
+        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
+        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
+
+        self.total_loss_tracker.update_state(total_loss)
+        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
+
+        return {
+            "loss": self.total_loss_tracker.result(),
+            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
+            "vqvae_loss": self.vq_loss_tracker.result()
+        }
+
+
+def plot_losses(history, time, name):
+    plt.figure()
+    for hist in history.history.keys():
+        if hist == "accuracy" or hist == "val_accuracy":
+            continue
+        plt.plot(history.history[hist], label=hist)
+
+    plt.title('Loss vs Epoch')
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(f"out/{time}/training_loss_curves_{name}.png")
+    plt.close()
+
+
+def train_vq(time, num_embeddings, latent_dim, batch_size):
+    (train_data, test_data, validate_data, data_variance) = dataset.oasis_dataset(10)
+
+    vqvae_trainer = Trainer(data_variance, latent_dim=latent_dim, num_embeddings=num_embeddings)
+    vqvae_trainer.compile(optimizer=tf.keras.optimizers.Adam())
+
+    history = vqvae_trainer.fit(train_data, epochs=10, batch_size=batch_size)
+
+    vqvae_trainer.vqvae.save(f"out/{time}/vqvae_model")
+    vqvae_trainer.vqvae.save_weights(f"out/{time}/vqvae_model_weights.h5")
+
+    plot_losses(history, time, "vq_vae")
+
+
+def train_pixel(time, num_embeddings, latent_dim, batch_size, vqvae_train_path):
+    (train_data, test_data, validate_data, data_variance) = dataset.oasis_dataset(10)
+    vqvae = tf.keras.models.load_model(vqvae_train_path)
+    embeddings = vqvae.get_layer("vector_quantizer").embeddings
+
+    # Encode the images and look up the nearest codebook index for every spatial position.
+    encoded_outputs = vqvae.get_layer("encoder").predict(test_data)
+    flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
+
+    codebook_indices = tf.argmin((tf.reduce_sum(flat_enc_outputs ** 2, axis=1, keepdims=True)
+                                  + tf.reduce_sum(embeddings ** 2, axis=0) - 2
+                                  * tf.matmul(flat_enc_outputs, embeddings)), axis=1).numpy().reshape(encoded_outputs.shape[:-1])
+
+    # The PixelCNN learns to predict each code from the codes before it, so input and target are the same grid.
+    pixel_cnn = modules.PixelCNN(latent_dim, num_embeddings, 2, 2)
+    pixel_cnn.compile(optimizer=tf.keras.optimizers.Adam(3e-4),
+                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                      metrics=["accuracy"])
+    history = pixel_cnn.fit(x=codebook_indices, y=codebook_indices, batch_size=batch_size,
+                            epochs=10, validation_split=0.2)
+
+    plot_losses(history, time, "pixelcnn")
+    pixel_cnn.save(f"out/{time}/pixelcnn_model")
+    pixel_cnn.save_weights(f"out/{time}/pixelcnn_model_weights.h5")
+
+
+def main():
+    # Hyphens instead of colons keep the output folder name valid on Windows as well.
+    time = datetime.now().strftime('%H-%M-%S')
+    os.makedirs(f"out/{time}", exist_ok=True)
+
+    num_embeddings = 128
+    latent_dim = 16
+    batch_size = 4
+
+    train_vq(time, num_embeddings, latent_dim, batch_size)
+    train_pixel(time, num_embeddings, latent_dim, batch_size, "samples/vqvae_model")
+
+
+if __name__ == "__main__":
+    main()