diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 6e10f0b1b8..0000000000 --- a/.gitignore +++ /dev/null @@ -1,135 +0,0 @@ -recognition/s4481540_Zhuoxiao_Chen/data/ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so -.idea -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# vscode config file -.vscode/ - -# pycharm project settings -.idea/ - -# no tracking mypy config file -mypy.ini diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/.Rhistory b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/.Rhistory new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/.gitignore b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/.gitignore new file mode 100644 index 0000000000..4981c7f217 --- /dev/null +++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/.gitignore @@ -0,0 +1,11 @@ +Makefile +plan.txt +s_run_script.sh +slurm.sh +trainer.weights.h5 +w-gan.py +venv/ +data/ +.DS_Store +trained_models/ +plot_history.py \ No newline at end of file diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/pixel-cnn-generator.py b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/pixel-cnn-generator.py new file mode 100644 index 0000000000..ac42254cb3 --- /dev/null +++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/pixel-cnn-generator.py @@ -0,0 +1,334 @@ +""" +Based on PixelCNN by ADMoreau available at +https://keras.io/examples/generative/pixelcnn/ + +PixelCNN to mimic latent space of the encoders output. 
+For generating novel samples.
+"""
+
+import keras
+import numpy as np
+from keras import layers
+
+import tensorflow as tf
+# from keras import mixed_precision
+# mixed_precision.set_global_policy("mixed_float16")  # speed up
+
+
+# for displaying
+import matplotlib.pyplot as plt
+import math
+
+import importlib
+vqvae = importlib.import_module("vq-vae")
+
+class MaskConstraint(keras.constraints.Constraint):
+    def __init__(self, mask):
+        self.mask = tf.constant(mask, dtype=tf.float32)
+    def __call__(self, w):
+        return w * self.mask
+
+class PixelConvLayer(layers.Layer):
+
+    """
+    A convolution layer that masks the kernel so it only sees pixels before
+    the current pixel - this allows for conditional (autoregressive) generation.
+    """
+
+    def __init__(
+        self,
+        mask_type,  # "B" lets the kernel see the current pixel, "A" does not
+        **kwargs,  # arguments for the convolution layer
+    ):
+        super().__init__()
+        self.mask_type = mask_type
+        self.conv = layers.Conv2D(**kwargs)
+
+    def build(
+        self, input_shape
+    ):
+        self.conv.build(input_shape)
+        k = self.conv.kernel
+
+        # get dimensions
+        kh, kw, cin, cout = k.shape
+
+        mask = np.zeros((kh, kw, cin, cout), dtype=np.float32)
+        mask[: kh // 2, :, :, :] = 1.0  # rows above
+        mask[kh // 2, : kw // 2, :, :] = 1.0  # same row, left
+
+        if self.mask_type == "B":
+            mask[kh // 2, kw // 2, ...] = 1.0  # center (only for B)
+
+        self.mask = tf.constant(mask, dtype=tf.float32)
+        # Hard-apply once at init to eliminate any initial leakage
+        self.conv.kernel.assign(self.conv.kernel * self.mask)
+        # Keep enforcing after each optimizer step
+        self.conv.kernel_constraint = MaskConstraint(self.mask)
+
+        print(self.mask_type, self.mask[..., 0, 0][k.shape[0] // 2])
+
+    def call(
+        self, x
+    ):
+        # re-apply the mask to the kernel before convolving
+        self.conv.kernel.assign(self.conv.kernel * self.mask)
+
+        return self.conv(x)
+
+
+def residual_pixel_layer(
+    x,
+    num_filters,
+):
+
+    """
+    a residual layer with two 1x1 convolutions around a 3x3 pixel convolution.
+ """ + start = x + + x = layers.Conv2D( + num_filters, 1, 1 + )(x) + x = keras.activations.relu(x) + + x = PixelConvLayer( + # "B", filters=num_filters//2, kernel_size=3, padding="same" + "B", filters=num_filters, kernel_size=3, padding="same" + )(x) + + x = keras.activations.relu(x) + + + x = layers.Conv2D( + num_filters, 1, 1 + )(x) + + x = keras.layers.add([x, start]) + x = keras.activations.relu(x) + + return x + + +def pixel_cnn_model( + input_shape, + num_residuals, + num_embeddings, + num_filters=128 +): + """ + returns a pixelCNN model + """ + + inputs = layers.Input(shape=input_shape, dtype=tf.int32) + # x = tf.cast(inputs, tf.float32) / 511 + # one hot + # x = layers.Embedding(num_embeddings, 3)(tf.squeeze(inputs, -1)) + x = tf.one_hot(tf.squeeze(inputs, -1), num_embeddings) + + # initial convolution + x = PixelConvLayer( + "A", + filters=num_filters, + kernel_size=7, + padding="same" + )(x) + + x = keras.activations.relu(x) + + for _ in range(num_residuals): + x = residual_pixel_layer( + x, num_filters + ) + + # end on 2 more pixelconvolutions + for _ in range(2): + x = PixelConvLayer( + "B", filters=num_filters, kernel_size=1, strides=1, padding='valid' + )(x) + + output = layers.Conv2D( + filters=num_embeddings, + kernel_size=1, + strides=1, + padding="valid" + )(x) + + return keras.Model(inputs=inputs, outputs=output) + +def generate_images( + pixelcnn, + quantizer, + decoder, + image_shape, # (rows, cols) + num_samples=16, +): + + rows, cols = image_shape + batch_size = num_samples + num_embeddings = quantizer.num_embeddings + + # initialise latent index grid + # priors = np.random.random_integers(0, num_embeddings-1, (batch_size, rows, cols, 1)) + priors = np.zeros((batch_size, rows, cols, 1)) + + for r in range(rows): + for c in range(cols): + logits = pixelcnn(priors) + logits_rc = logits[:, r, c, :] / 0.7 + sample = tf.random.categorical(logits_rc, 1) + priors[:, r, c, 0] = tf.squeeze(sample, -1).numpy().astype(np.int32) + + # back to one-hot + priors = priors.squeeze(-1) + indices = tf.one_hot(priors, num_embeddings) + + codebook = tf.convert_to_tensor(quantizer.codebook, dtype=tf.float32) + quantized = tf.matmul(indices, tf.transpose(codebook)) + + # decode to image space + generated = decoder(quantized, training=False) + + return generated + + +# pass image dataset through new pipeline +def dataset_pipeline(img): + # add batch and channel dimension + img = tf.expand_dims(img, axis=0) # batch + img = tf.expand_dims(img, axis=-1) # channel + + prediction = encoder(img, training=False) + flattened = tf.reshape(prediction, (-1, prediction.shape[-1])) + + # quantize + indices = quantizer.get_code_indices(flattened) + indices = tf.cast(indices, tf.int32) + indices = tf.reshape(indices, prediction.shape[1:-1]) + indices = tf.expand_dims(indices, -1) # channel dim + + return (indices, indices) + + +def load_model(path): + shape = ( + vqvae.get_dataset(vqvae.TRAIN_FOLDER) + .map(dataset_pipeline) + .element_spec[0].shape#[1:] + ) + + new_pixel = pixel_cnn_model( + shape, + 4, + vqvae.CODEBOOK_SIZE, + num_filters=32 + ) # new model with same parameters + + + new_pixel.compile("adam") + new_pixel.build(shape) + new_pixel.load_weights(path) + + return new_pixel + + +def show_generated_images(batch, title="Generated_Images"): + """ + Display a batch of generated grayscale images as a square grid. 
+ + Args: + batch: np.ndarray or tf.Tensor of shape (N, H, W, 1) or (N, H, W) + N = number of images + title: optional string for figure title + save file name + """ + + # Convert to numpy + batch = np.array(batch) + + # Drop channel dimension if present + if batch.ndim == 4 and batch.shape[-1] == 1: + batch = batch[..., 0] + + n = batch.shape[0] + grid_size = math.ceil(math.sqrt(n)) + + fig, axes = plt.subplots(grid_size, grid_size, figsize=(grid_size * 2, grid_size * 2)) + axes = axes.flatten() + + for i, ax in enumerate(axes): + ax.axis("off") + if i < n: + ax.imshow(batch[i], cmap="gray") + plt.suptitle(title) + plt.tight_layout() + plt.savefig(title+".jpg") + + +if __name__ == "__main__": + # load existing vqvae model + vae_model = vqvae.load_model( + "trainer.weights.h5" + ) + + # get components + encoder = vae_model.get_layer("encoder") + decoder = vae_model.get_layer("decoder") + quantizer = vae_model.get_layer("quantizer") + + print("Codebook shape:", quantizer.codebook.shape) + + + # load the dataset if it already exists + dataset_save_path = "pixelcnn_dataset" + try: + pixelcnn_dataset = tf.data.Dataset.load(dataset_save_path) + except: + print("Dataset not found, creating now...") + # create pixelcnn dataset + image_dataset = vqvae.get_dataset( + vqvae.TRAIN_FOLDER + ) + + pixelcnn_dataset = image_dataset.map(dataset_pipeline) + pixelcnn_dataset = pixelcnn_dataset.cache().batch(64) + + # save so it isnt used again + pixelcnn_dataset.save(dataset_save_path) + + # actually need to remove channel dim on Y + pixelcnn_dataset = pixelcnn_dataset.map(lambda x, y: (x, tf.squeeze(y, -1))) + pixelcnn_dataset = pixelcnn_dataset.shuffle(pixelcnn_dataset.cardinality()) + + # initialise pixelcnn object + io_shape = pixelcnn_dataset.element_spec[0].shape[1:] # it adds batch dim auto + pcnn = pixel_cnn_model( + io_shape, 8, vqvae.CODEBOOK_SIZE, num_filters=256 + ) + + pcnn.compile( + keras.optimizers.Adam( + # dtype_policy="mixed_float16", + learning_rate=0.001, + clipnorm=1.0 + ), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[keras.metrics.SparseCategoricalAccuracy()] + ) + + # load in existing weights + # pcnn = load_model("pixel.weights.h5") + + pcnn.fit( + pixelcnn_dataset, + epochs=200 + ) + + generated_outputs = generate_images( + pcnn, quantizer, decoder, io_shape[:-1] + ) + + show_generated_images(generated_outputs, title="Novel_Generated_Outputs_3") + + # save model + pcnn.save_weights( + "pixel_round_2.weights.h5" + ) + + diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_3_better.jpg b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_3_better.jpg new file mode 100644 index 0000000000..ac0622922a Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_3_better.jpg differ diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_working.jpg b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_working.jpg new file mode 100644 index 0000000000..702cd72af9 Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/Novel_Generated_Outputs_working.jpg differ diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/loss.png b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/loss.png new file mode 100644 index 0000000000..ecbdcc83b3 Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/loss.png differ diff --git 
a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/model_architecture.png b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/model_architecture.png
new file mode 100644
index 0000000000..58680e924b
Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/model_architecture.png differ
diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction.jpg b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction.jpg
new file mode 100644
index 0000000000..e674c076fe
Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction.jpg differ
diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction_good.jpg b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction_good.jpg
new file mode 100644
index 0000000000..3528ff73a0
Binary files /dev/null and b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/plots/reconstruction_good.jpg differ
diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/readme.md b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/readme.md
new file mode 100644
index 0000000000..d42e1754c9
--- /dev/null
+++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/readme.md
@@ -0,0 +1,46 @@
+## VQ-VAE and PixelCNN Implementation
+
+This completes problem 10 of the COMP3710 (Pattern Recognition) assignment 2025.
+
+All requirements for this project are listed in requirements.txt.
+
+#### VQ-VAE
+
+vq-vae.py contains a TensorFlow/Keras implementation of a vector-quantised variational autoencoder. It uses a default architecture designed by me and takes as hyperparameters only the latent (embedding) dimension and the number of embeddings (i.e. the codebook size).
+
+![model architecture](plots/model_architecture.png)
+
+The encoder consists of 3 downsampling stages, each doubling the filters from 128 while halving the spatial dimensions, and each followed by residual blocks. The quantizer maps the embeddings to the closest of the 512 codebook vectors, and the decoder performs the same operations as the encoder but in reverse (with transpose convolutions). This model was trained for 10 epochs.
+
+![loss](plots/loss.png)
+
+The end result is the following reconstructions:
+
+![reconstruction 1](plots/reconstruction_good.jpg)
+
+![reconstruction 2](plots/reconstruction.jpg)
+
+This was tested in ssim-evaluation.py, scoring a mean structural similarity (SSIM) of 0.78039 across the test set.
+
+This is with 3 halvings of the spatial dimensions and a latent dimension of 3, corresponding to roughly 21x compression in the latent representation (128x256 in image space versus 16x32x3 in the latent space).
+
+#### PixelCNN
+
+Then, in pixel-cnn-generator.py, a PixelCNN model was trained on the latent distribution, learning to predict each latent vector index conditioned on the indices before it.
+
+The PixelCNN consists of an initial pixel-convolution layer followed by 8 residual pixel-convolution layers and finally 2 more pixel convolutions. A pixel convolution is a standard convolution with all kernel entries at and after the current pixel zeroed out. This is what makes the PixelCNN conditional: it may only determine the current output entry from what has come before it. A residual pixel layer is a pixel convolution between 2 regular convolutions with a skip connection across all three, as sketched below.
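+
+For illustration, here is a minimal sketch of how the kernel mask is built (it mirrors `PixelConvLayer` in pixel-cnn-generator.py; the standalone helper name and the 3x3 kernel size are just for this example):
+
+```python
+import numpy as np
+
+def build_pixel_mask(kh, kw, mask_type="B"):
+    # 1.0 = kernel entry is kept, 0.0 = masked out
+    mask = np.zeros((kh, kw), dtype=np.float32)
+    mask[: kh // 2, :] = 1.0          # all rows above the current pixel
+    mask[kh // 2, : kw // 2] = 1.0    # same row, strictly to the left
+    if mask_type == "B":
+        mask[kh // 2, kw // 2] = 1.0  # type "B" may also see the current pixel
+    return mask
+
+# 3x3 example: type "A" masks the centre pixel, type "B" keeps it
+print(build_pixel_mask(3, 3, "A"))
+print(build_pixel_mask(3, 3, "B"))
+```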
+
+The PixelCNN was trained for ~50 epochs on the latent representation output by the encoder and converged (on the best run) with a loss of ~2.7.
+
+Unlike other models, the convergent error of a PixelCNN scales with the codebook size, so this loss indicates a decent model, but one with room for improvement. A perfect model would be closer to 2 for the codebook size used.
+
+This model can then be used in the generate_images function to sample the learned distribution. The sampled indices are looked up in the codebook and run through the decoder to produce novel examples.
+
+At first the generated examples were very rough:
+![rough](plots/Novel_Generated_Outputs_working.jpg)
+
+But by one-hot encoding the input, adjusting the sampling temperature and tweaking the model architecture (to that described above), they could be improved to the following:
+
+![better](plots/Novel_Generated_Outputs_3_better.jpg)
+
+These show clear features present in the hip MRI scans shown above.
\ No newline at end of file
diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/requirements.txt b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/requirements.txt
new file mode 100644
index 0000000000..6ac0faa5db
--- /dev/null
+++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/requirements.txt
@@ -0,0 +1,8 @@
+tensorflow==2.15
+tensorflow_probability
+keras==2.15
+tf-keras==2.15
+numpy
+scikit-image
+nibabel
+matplotlib
\ No newline at end of file
diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/ssim-evaluation.py b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/ssim-evaluation.py
new file mode 100644
index 0000000000..e71c312745
--- /dev/null
+++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/ssim-evaluation.py
@@ -0,0 +1,54 @@
+import tensorflow as tf
+
+import importlib
+vqvae = importlib.import_module("vq-vae")
+
+"""
+EVALUATED: 0.78039
+"""
+
+def evaluate_ssim(model, test_dataset, max_val=1.0):
+    """
+    Compute average SSIM over a test dataset for a VQ-VAE model.
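+
+    A minimal usage sketch, assuming the helpers in vq-vae.py and trained
+    weights saved as trainer.weights.h5 (the test-set path is a placeholder;
+    see the __main__ block below for the real one):
+
+        vae = vqvae.load_model("trainer.weights.h5")
+        test = vqvae.get_dataset("<path to keras_slices_test>").batch(10)
+        print(evaluate_ssim(vae, test))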
+ + Args: + model: a trained VQ-VAE model with .predict or __call__ method + test_dataset: a tf.data.Dataset yielding batches of images in [0, 1] range + max_val: maximum pixel value (default assumes images scaled to [0,1]) + + Returns: + float: average SSIM score across the dataset + """ + ssim_scores = [] + + for batch in test_dataset: + # Ensure batch is float32 + batch = tf.cast(batch, tf.float32) + + # Get reconstructions from model + recon = model(batch, training=False) + + # Compute SSIM for each image in the batch + batch_ssim = tf.image.ssim(batch, recon, max_val=max_val) + + # Collect as numpy values + ssim_scores.extend(batch_ssim.numpy()) + + return float(tf.reduce_mean(ssim_scores).numpy()) + +if __name__ == "__main__": + # fetch test set + test = vqvae.get_dataset( + # "data/keras_slices_data/keras_slices_train/" + "/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_test" + ).batch(10) + + vae = vqvae.load_model( + "trainer.weights.h5" + ) + + ssim = evaluate_ssim( + vae, test + ) + + print(ssim) \ No newline at end of file diff --git a/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/vq-vae.py b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/vq-vae.py new file mode 100644 index 0000000000..3512ce72e6 --- /dev/null +++ b/recognition/Hip_MRI_VQVAE_PixelCNN_48036177/vq-vae.py @@ -0,0 +1,522 @@ +import tensorflow as tf +import keras +from keras import layers + +# for data +import os +import nibabel as nib +import numpy as np + +# testing +import matplotlib.pyplot as plt +from skimage.transform import resize +import json + +""" +Based on example: 'Vector-Quantized Variational Autoencoders' by Sayak Paul +href: https://keras.io/examples/generative/vq_vae/ +""" + +""" +IDEAS: Could increase the number of hidden units across the board +""" + +# TRAIN_FOLDER = "data/keras_slices_data/keras_slices_train/" +TRAIN_FOLDER = "/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_train" +HISTORY_PATH = "history.json" +LATENT_DIM = 3 +CODEBOOK_SIZE = 512 + +class VectorQuantizerLayer(layers.Layer): + + def __init__( + self, + num_embeddings, # number of vectors in codebook + embedding_dim, + commitment_weight=0.25, # coefficient of commitment loss + **kwargs + ): + super().__init__(**kwargs) + + self.commitment_weight=commitment_weight + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + + # Initialise codebook + self.codebook = self.add_weight( + name="VQ-CODEBOOK", + shape=(self.embedding_dim, self.num_embeddings), + initializer=tf.random_uniform_initializer(), + trainable=True + ) + + + def call(self, x): + """ + Given a layer of network, replace each vector in the layer with a + vector from the codebook - quantize the layer + """ + + # store starting shape + start_shape = tf.shape(x) + + # flatten to single row of embedding_dim vectors + flattened = tf.reshape(x, [-1, self.embedding_dim]) + + # actually quantize + # indices of closest code vectors + indices = self.get_code_indices(flattened) + + # one-hot the needed vector + one_hot_indices = tf.one_hot(indices, self.num_embeddings) + + # extract the given index from the codebook - one hot selects correct + # codebook vector + quantized = tf.matmul(one_hot_indices, self.codebook, transpose_b=True) + quantized = tf.reshape(quantized, start_shape) + + # loss definitions + loss_commitment = tf.reduce_mean( + # dont let this loss flow to the codebook - only to update encoder + # so detach with stop_gradient + + # MSE to get encoder to give vectors close to codebook vectors + 
(tf.stop_gradient(quantized) - x)**2 + ) + + loss_codebook = tf.reduce_mean( + # for codebook only, stop gradient from flowing to encoder + (quantized - tf.stop_gradient(x)) ** 2 + ) + + # add to model + loss_total = self.commitment_weight * loss_commitment + loss_codebook + self.add_loss(loss_total) + + # trick to stop gradient flowing from rest of model going through + # the codebook - quantized is gradient stopped but we're still passing + # the quantized output -- gradient flows through x + gradient_skipped_quantized = x + tf.stop_gradient(quantized - x) + return gradient_skipped_quantized + + def get_code_indices(self, flattened): + """ + calculate L2 distance to all codebook vectors and return the index + of the closest one for each vector in flattened + """ + + ab = tf.matmul( + flattened, self.codebook + ) + + distances = (#L2 + tf.reduce_sum(flattened ** 2, axis=1, keepdims=True) + + tf.reduce_sum(self.codebook ** 2, axis=0) + -2*ab + ) + + smallest_indices = tf.argmin( + distances, axis=1 + ) + + return smallest_indices + + +def residual_block(x, num_filters): + """ + standard residual block. used for encoder and decoder both + """ + start = x + x = keras.layers.Conv2D( + num_filters, 3, 1, padding="same" + )(x) + + x = keras.activations.relu(x) + + x = keras.layers.Conv2D( + num_filters, 1, 1, padding="same" + )(x) + + x = keras.layers.Add()([x, start]) + x = keras.activations.relu(x) + + return x + + +def get_encoder( + input_shape, + layers_filters, + latent_dim=64, + ): + """ + returns an encoder + + layers: list of kernal sizes to iterate over + """ + + inputs = keras.Input(shape=input_shape) + + x = inputs + for num_filters in layers_filters: + x = keras.layers.Conv2D( + num_filters, + 4, + strides=2, + padding="same" + )(x) + + x = keras.activations.relu(x) + + # down sampling followed by residual blocks + x = residual_block(x, num_filters) + x = residual_block(x, num_filters) + + # use 1x1 kernel to convert to latent dim size + outputs = keras.layers.Conv2D( + latent_dim, 1, padding="same" + )(x) + + return keras.Model(inputs, outputs, name='encoder') + +def get_decoder( + layers_filters, + input_shape, # shape of encoder output + latent_dim = 64 +): + """ + layers_filters should be the same as the encoder but reversed. + + Returns decoder object. 
+ """ + + inputs = keras.Input( + shape=input_shape + ) + x = inputs + for num_filters in layers_filters: + + x = keras.layers.Conv2DTranspose( + num_filters, 4, strides=2, padding="same" + )(x) + + x = keras.activations.relu(x) + + x = residual_block(x, num_filters) + x = residual_block(x, num_filters) + + # 1x1 kernel to bring to a single layer + x = layers.Conv2DTranspose(1, 3, padding="same")(x) + outputs = keras.activations.sigmoid(x) # scale to [0,1] + return keras.Model( + inputs, outputs, name="decoder" + ) + +def vq_vae( + input_shape, + latent_dim=64, + num_embeddings=128, +): + quantize_layer = VectorQuantizerLayer( + num_embeddings, latent_dim, name="quantizer" + ) + + filter_layers = [ + 128, 256, 512 + ] + + encoder = get_encoder( + latent_dim=latent_dim, + input_shape=input_shape, + layers_filters=filter_layers + ) + + decoder = get_decoder( + filter_layers[::-1], # reverse order to encoder to end in same shape + encoder.output.shape[1:], + latent_dim=latent_dim + ) + + inputs = keras.Input(shape=input_shape) + encoder_out = encoder(inputs) + quantized_out = quantize_layer(encoder_out) + reconstruction = decoder(quantized_out) + + return keras.Model( + inputs, reconstruction, name="vq-vae" + ) + +class trainer(keras.models.Model): + """ + Class to train a vq-vae + """ + def __init__( + self, + train_variance, + latent_dim, + codebook_size, + input_shape, + **kwargs + ): + # init + super().__init__(**kwargs) + self.train_variance = train_variance + self.latent_dim = latent_dim + self.codebook_size = codebook_size + + + # actual model + self.vqvae = vq_vae( + input_shape, latent_dim, codebook_size + ) + + # initialise losses - all mean to average over batch + self.total_loss = keras.metrics.Mean("total_loss") + self.reconstruction_loss = keras.metrics.Mean("reconstruction_loss") + self.vqvae_loss = keras.metrics.Mean("vqvae_loss") + + def call(self, x): + return self.vqvae(x) + + # override metrics so our custom metrics are reset on epochs/tracked + @property + def metrics(self): + return [ + self.total_loss, self.reconstruction_loss, self.vqvae_loss + ] + + def train_step(self, x): + + with tf.GradientTape() as tape: + # get output + reconstruction = self.vqvae(x) + + # calculate loss - reconstrution normalised to dataset variance + recon_loss = tf.reduce_mean((x-reconstruction)**2) #/ self.train_variance + + total_loss = ( + recon_loss + + sum(self.vqvae.losses) + ) + + # now backpropegate + grad = tape.gradient( + # derivative of total_loss w.r.t all trainable variables + total_loss, self.vqvae.trainable_variables + ) + self.optimizer.apply_gradients( # run gradient descent + zip(grad, self.vqvae.trainable_variables) + ) + + # track losses + self.total_loss.update_state(total_loss) + self.reconstruction_loss.update_state(recon_loss) + self.vqvae_loss.update_state(sum(self.vqvae.losses)) + + # return result + return { + "loss" : self.total_loss.result(), + "reconstruction loss": self.reconstruction_loss.result(), + "vqvae loss": self.vqvae_loss.result() + } + +def __NUMPY__load_image_to_tensor(filepath: str): + """ + loads the image at filepath and returns it as a greyscale numpy array + with values between [0,1] + """ + filepath = filepath.decode("utf-8") + # load into greyscale + img = nib.load(filepath) + + # as np array + data = img.get_fdata().astype(np.float32) + zoom = img.header.get_zooms() + + # Normalize to [0,1] + data = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8) + + data_resized = resize( + data, (256, 128), order=1, preserve_range=True, 
anti_aliasing=True + ).astype(np.float32) + + # expand channels on end to 1 + data = np.expand_dims(data_resized, axis=-1) + + return data # returns as numpy + +def load_image_to_tensor(filepath): + data = tf.numpy_function( # execute and cast to tensor + __NUMPY__load_image_to_tensor, [filepath], tf.float32 + ) + # define the shape + data.set_shape([256, 128, 1]) + print(data) + + return data + + +def get_dataset(folder): + train_files = [ + os.path.join(folder, f) + for f in os.listdir(folder) + if f.endswith(".nii.gz") + ] + + train_data = tf.data.Dataset.from_tensor_slices( + train_files + ) + + train_files = train_data.map(load_image_to_tensor) + + return train_files + +def get_variance(dataset): + """ + Compute variance of all elements in a tf.data.Dataset. + Assumes dataset yields numeric tensors (scalars or arrays). + """ + # First pass: compute mean + count = 0 + total = 0.0 + for x in dataset: + total += tf.reduce_sum(tf.cast(x, tf.float32)) + count += tf.size(x).numpy() + mean = total / count + + # Second pass: compute squared differences + sq_diff_sum = 0.0 + for x in dataset: + sq_diff_sum += tf.reduce_sum((tf.cast(x, tf.float32) - mean) ** 2) + + variance = sq_diff_sum / count + return variance.numpy() + + +def get_image_shape(): + """ + returns the shape of an image in the train dataset + """ + # fetch the shape of the output + for batch in get_dataset(TRAIN_FOLDER): + img_shape = batch.shape + return img_shape + + +def plot_reconstruction(original, reconstruction): + """ + Plots side by side the two greyscale tensors + """ + fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + + axes[0].imshow(original) + axes[1].imshow(reconstruction) + + plt.tight_layout() + plt.savefig("reconstruction.jpg") + + +def main(): + + train_variance = get_variance( + get_dataset(TRAIN_FOLDER) + ) + + train_dataset = get_dataset(TRAIN_FOLDER).batch(128) + + data_augmentation = keras.Sequential([ + layers.RandomRotation(0.05), # 10% rotation + layers.RandomZoom(0.1), # zoom in/out + layers.RandomTranslation(0.05, 0.05), # shift + layers.RandomContrast(0.1), + ]) + + train_dataset.map(lambda x: data_augmentation(x, training=True)) + + img_shape = get_image_shape() + print(f"Image shape: {img_shape}") + + # actuall train the model + + vq_trainer = trainer( + train_variance, LATENT_DIM, CODEBOOK_SIZE, img_shape + ) + + vq_trainer.compile( + optimizer=keras.optimizers.Adam(learning_rate=0.001), + ) + + history = vq_trainer.fit( + train_dataset, + epochs=1, + + ) + + model = vq_trainer.vqvae + + ### Test save/load + weights_path = "trainer.weights.h5" + model.save_weights(weights_path) + model = load_model(weights_path) + + + ####### PLOT RECONSTRUCTION + + for batch in get_dataset(TRAIN_FOLDER): + img = batch # just fetch one image + break + + img_batched = tf.expand_dims(img, axis=0) # add batch dim + img_batched = tf.expand_dims(img_batched, axis=-1) # add channel dim + reconstruction = tf.squeeze(model(img_batched), axis=0) + + plot_reconstruction( + img, reconstruction + ) + + # save loss + with open(HISTORY_PATH, 'w+') as f: + json.dump(history.history['loss'], f) + +def load_model(path): + """ + loads the vqvae weights stored at the path into a new vqvae model + + to save weights for use by this helper. 
trainer.vqvae.save_weights(path) + """ + + new_vq_vae = vq_vae( + get_image_shape(), + LATENT_DIM, + CODEBOOK_SIZE + ) + + new_vq_vae.compile(optimizer="adam") + new_vq_vae.build(get_image_shape()) + new_vq_vae.load_weights(path) + + return new_vq_vae + +def test_load(): + model = load_model("trained_models/trainer.weights.h5") + model.summary() + +if __name__ == "__main__": + # main() + ####### PLOT RECONSTRUCTION + + model = load_model("trainer.weights.h5") + + i = 0 + for batch in get_dataset(TRAIN_FOLDER): + img = batch # just fetch one image + if i > 25: + break + i += 1 + + img_batched = tf.expand_dims(img, axis=0) # add batch dim + img_batched = tf.expand_dims(img_batched, axis=-1) # add channel dim + reconstruction = tf.squeeze(model(img_batched), axis=0) + + plot_reconstruction( + img, reconstruction + ) \ No newline at end of file