diff --git a/recognition/ISICs_Unet/README.md b/recognition/ISICs_Unet/README.md
index f2c009212e..788ea17b79 100644
--- a/recognition/ISICs_Unet/README.md
+++ b/recognition/ISICs_Unet/README.md
@@ -1,52 +1,101 @@
-# Segmenting ISICs with U-Net
+# Segmenting the ISIC data set with U-net
-COMP3710 Report recognition problem 3 (Segmenting ISICs data set with U-Net) solved in TensorFlow
+## Project Overview
+This project aims to segment skin lesions (the ISIC2018 data set) with a U-net, targeting a minimum Dice similarity coefficient of 0.7 for all labels on the test set [Task 3].
-Created by Christopher Bailey (45576430)
+## ISIC2018
+![ISIC example](imgs/example.jpg)
-## The problem and algorithm
-The problem solved by this program is binary segmentation of the ISICs skin lesion data set. Segmentation is a way to label pixels in an image according to some grouping, in this case lesion or non-lesion. This translates images of skin to masks representing areas of concern for skin lesions.
+Skin Lesion Analysis towards Melanoma Detection
-U-Net is a form of autoencoder where the downsampling path is expected to learn the features of the image and the upsampling path learns how to recreate the masks. Long skip connections between downpooling and upsampling layers are utilised to overcome the bottleneck in traditional autoencoders allowing feature representations to be recreated.
+The task is described at https://challenge2018.isic-archive.com/
-## How it works
-A four layer padded U-Net is used, preserving skin features and mask resolution. The implementation utilises Adam as the optimizer and implements Dice distance as the loss function as this appeared to give quicker convergence than other methods (eg. binary cross-entropy).
-The utilised metric is a Dice coefficient implementation. My initial implementation appeared faulty and was replaced with a 3rd party implementation which appears correct. 3 epochs was observed to be generally sufficient to observe Dice coefficients of 0.8+ on test datasets but occasional non-convergence was observed and could be curbed by increasing the number of epochs. Visualisation of predictions is also implemented and shows reasonable correspondence. Orange bandaids represent an interesting challenge for the implementation as presented.
+## U-net
+![UNet](imgs/uent.png)
-### Training, validation and testing split
-Training, validation and testing uses a respective 60:20:20 split, a commonly assumed starting point suggested by course staff. U-Net in particular was developed to work "with very few training images" (Ronneberger et al, 2015) The input data for this problem consists of 2594 images and masks. This split appears to provide satisfactory results.
+U-net is a popular image segmentation architecture used mostly for biomedical applications. It is named for its shape: the architecture consists of a contracting path and an expansive path that together form a U. The architecture is designed to produce good results even with a small number of training images.
-## Using the model
-### Dependencies required
-* Python3 (tested with 3.8)
-* TensorFlow 2.x (tested with 2.3)
-* glob (used to load filenames)
-* matplotlib (used for visualisations, tested with 3.3)
+## Data Set Structure
-### Parameter tuning
-The model was developed on a GTX 1660 TI (6GB VRAM) and certain values (notably batch size and image resolution) were set lower than might otherwise be ideal on more capable hardware. This is commented in the relevant code.
+The data set folder needs to be stored in the same directory, with the structure shown below:
+```bash
+ISIC2018
+    |_ ISIC2018_Task1-2_Training_Input_x2
+        |_ ISIC_0000000
+        |_ ISIC_0000001
+        |_ ...
+    |_ ISIC2018_Task1_Training_GroundTruth_x2
+        |_ ISIC_0000000_segmentation
+        |_ ISIC_0000001_segmentation
+        |_ ...
+```
-### Running the model
-The model is executed via the main.py script.
+## Dice Coefficient
-### Example output
-Given a batch size of 1 and 3 epochs the following output was observed on a single run:
-Era | Loss | Dice coefficient
---- | ---- | ----------------
-Epoch 1 | 0.7433 | 0.2567
-Epoch 2 | 0.3197 | 0.6803
-Epoch 3 | 0.2657 | 0.7343
-Testing | 0.1820 | 0.8180
+The Sørensen–Dice coefficient is a statistic used to gauge the similarity of two samples.
+Further information: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
-### Figure 1 - example visualisation plot
-Skin images in left column, true mask middle, predicted mask right column
-![Visualisation of predictions](visual.png)
+## Dependencies
-## References
-Segments of code in this assignment were used from or based on the following sources:
-1. COMP3710-demo-code.ipynb from Guest Lecture
-1. https://www.tensorflow.org/tutorials/load_data/images
-1. https://www.tensorflow.org/guide/gpu
-1. Karan Jakhar (2019) https://medium.com/@karan_jakhar/100-days-of-code-day-7-84e4918cb72c
+- Python 3
+- TensorFlow 2.1.0
+- pandas 1.1.4
+- NumPy 1.19.2
+- Matplotlib 3.3.2
+- scikit-learn 0.23.2
+- Pillow 8.0.1
+
+
+## Usage
+
+- Run `train.py` to train the UNet on the ISIC data.
+- Run `evaluation.py` to evaluate the model and plot example predictions.
+
+## Advanced usage
+
+- Modify `setting.py` for custom settings, such as a different batch size.
+- Modify `unet.py` for a custom UNet, such as a different kernel size.
+
+## Algorithm
+
+- Data set:
+  - The data set used is the training set of the ISIC 2018 challenge, which has segmentation labels.
+  - Training : Validation : Test = 1660 : 415 : 519 = 0.64 : 0.16 : 0.2 (Training : Test = 4 : 1, and the training portion is further split 4 : 1 into Training : Validation)
+  - Training data augmentations: rescale, rotate, shift, zoom, grayscale
+- Model:
+  - Original UNet with padding, so the input and output have the same spatial shape.
+  - The first convolutional layer has 16 output channels.
+  - The activation function of all convolutional layers is ELU.
+  - No batch normalization layers.
+  - The input shape is (384, 512, 1).
+  - The output shape is (384, 512, 1), after a sigmoid activation.
+  - Optimizer: Adam, lr = 1e-4
+  - Loss: Dice coefficient loss
+  - Metrics: accuracy & Dice coefficient
+
+## Results
+
+The evaluation Dice coefficient on the test set is 0.8053.
+
+Plot of the training/validation Dice coefficient:
+
+![img](imgs/train_and_valid_dice_coef.png)
+
+Example predictions:
+
+![case](imgs/case%20present.png)
+
+## References
+Manna, S. (2020). K-Fold Cross Validation for Deep Learning using Keras. [online] Medium. Available at: https://medium.com/the-owl/k-fold-cross-validation-in-keras-3ec4a3a00538 [Accessed 24 Nov. 2020].
+
+zhixuhao (2020). zhixuhao/unet. [online] GitHub. Available at: https://github.com/zhixuhao/unet.
+
+GitHub. (n.d.). NifTK/NiftyNet. [online] Available at: https://github.com/NifTK/NiftyNet/blob/a383ba342e3e38a7ad7eed7538bfb34960f80c8d/niftynet/layer/loss_segmentation.py [Accessed 24 Nov. 2020].
+
+Team, K. (n.d.). Keras documentation: Losses. [online] keras.io. Available at: https://keras.io/api/losses/#creating-custom-losses [Accessed 24 Nov. 2020].
+
+262588213843476 (n.d.). unet.py. [online] Gist.
Available at: https://gist.github.com/abhinavsagar/fe0c900133cafe93194c069fe655ef6e [Accessed 24 Nov. 2020]. + +Stack Overflow. (n.d.). python - Disable Tensorflow debugging information. [online] Available at: https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information [Accessed 24 Nov. 2020]. diff --git a/recognition/OASIS-Brain-StableDiffusion/dataset.py b/recognition/OASIS-Brain-StableDiffusion/dataset.py new file mode 100644 index 0000000000..6662218085 --- /dev/null +++ b/recognition/OASIS-Brain-StableDiffusion/dataset.py @@ -0,0 +1,32 @@ +from torch.utils.data import DataLoader +import torchvision + + + +def load_dataset(path, image_size=64, batch_size=64): + """ + Normalizes and loads images from a specified dataset into a dataloader + + Args: + path (str): path to the folder containing the dataset + image_size (int, optional): size, W, of the image (WxW). Defaults to 256. + batch_size (int, optional): batch size for the dataloader. Defaults to 64. + + Returns: + DataLoader: pyTorch dataloader of the dataset + """ + # define the transform used to normalize the input data + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(image_size+round(0.25*image_size)), + torchvision.transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ] + ) + + # create the pyTorch dataset and dataloader + dataset = torchvision.datasets.ImageFolder(root=path, transform=transforms) + dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + return dataset_loader diff --git a/recognition/OASIS-Brain-StableDiffusion/modules.py b/recognition/OASIS-Brain-StableDiffusion/modules.py new file mode 100644 index 0000000000..76e4bfa4d6 --- /dev/null +++ b/recognition/OASIS-Brain-StableDiffusion/modules.py @@ -0,0 +1,279 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +import torchvision +import math + + +class ConvReluBlock(nn.Module): + """ + ConvReluBlock object consisting of a double convolution rectified + linear layer and a group normalization used at every level of the UNET model + """ + def __init__(self, dim_in, dim_out, residual_connection=False): + """ + Block class constructor to initialize the object + + Args: + dim_in (int): number of channels in the input image + dim_out (int): number of channels produced by the convolution + residual_connection (bool, optional): true if this block has a residual connect, false otherwise. Defaults to False. 
+ """ + super(ConvReluBlock, self).__init__() + self.residual_connection = residual_connection + self.conv1 = nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1, bias=False) + self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1, bias=False) + self.relu = nn.ReLU() + self.gNorm = nn.GroupNorm(1, dim_out) + + def forward(self, x): + """ + Method to run an input tensor forward through the block + and returns the resulting output tensor + + Args: + x (Tensor): input tensor + + Returns: + Tensor: output tensor + """ + # Block 1 + x1 = self.conv1(x) + x2 = self.gNorm(x1) + x3 = self.relu(x2) + + # Block 2 + x4 = self.conv2(x3) + x5 = self.gNorm(x4) + + # Handle Residuals + if (self.residual_connection): + x6 = F.relu(x + x5) + else: + x6 = x5 + + return x6 + +class EncoderBlock(nn.Module): + """ + Encoder block consisting of a max pooling layer followed by 2 ConvReluBlocks + concatenated with the embedded position tensor + """ + def __init__(self, dim_in, dim_out, emb_dim=256): + """ + Encoder Block class constructor to initialize the object + + Args: + dim_in (int): number of channels in the input image. + dim_out (int): number of channels produced by the convolution. + emb_dim (int, optional): number of channels in the embedded layer. Defaults to 256. + """ + super(EncoderBlock, self).__init__() + self.encoder_block1 = ConvReluBlock(dim_in, dim_in, residual_connection=True) + self.encoder_block2 = ConvReluBlock(dim_in, dim_out) + self.pool = nn.MaxPool2d(2) + self.embedded_block = nn.Sequential(nn.ReLU(), nn.Linear(emb_dim, dim_out)) + + def forward(self, x, position): + """ + Method to run an input tensor forward through the encoder + and returns the resulting tensor + + Args: + x (Tensor): input tensor + position (Tensor): position tensor + + Returns: + Tensor: output tensor concatenated with the position tensor + """ + x = self.pool(x) + x = self.encoder_block1(x) + x = self.encoder_block2(x) + emb_x = self.embedded_block(position)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return x + emb_x + +class DecoderBlock(nn.Module): + """ + Decoder block consisting of an upsample layer followed by 2 ConvReluBlocks + concatenated with the embedded position tensor + """ + def __init__(self, dim_in, dim_out, emb_dim=256): + """ + Decoder Block class constructor to initialize the object + + Args: + dim_in (int): number of channels in the input image. + dim_out (int): number of channels produced by the convolution. + emb_dim (int, optional): number of channels in the embedded layer. Defaults to 256. 
+ """ + super(DecoderBlock, self).__init__() + self.upSample_block = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.decoder_block1 = ConvReluBlock(dim_in, dim_in, residual_connection=True) + self.decoder_block2 = ConvReluBlock(dim_in, dim_out) + self.embedded_block = nn.Sequential(nn.ReLU(), nn.Linear(emb_dim, dim_out)) + self.embedded_block = nn.Sequential(nn.ReLU(), nn.Linear(emb_dim, dim_out)) + + def forward(self, x, skip_tensor, position): + """ + Method to run an input tensor forward through the decoder + and returns the output tensor + + Args: + x (Tensor): input tensor + skip_tensor (Tensor): tensor representing the skip connection from the encoder + position (Tensor): position tensor result of positional encoding + + Returns: + Tensor: output tensor concatenated with the position tensor + """ + + x = self.upSample_block(x) + x = torch.cat([skip_tensor, x], dim=1) + x = self.decoder_block1(x) + x = self.decoder_block2(x) + emb_x = self.embedded_block(position)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return emb_x + x + +class AttentionBlock(nn.Module): + """ + Transformer attention block to enhance some parts of the + data and diminish other parts + """ + def __init__(self, dims, dim_size): + """ + Attention Block class constructor to initialize the object + + Args: + dims (int): number of channels + dim_size (int): size of channels + """ + super(AttentionBlock, self).__init__() + self.dims = dims + self.dim_size = dim_size + self.mha_block = nn.MultiheadAttention(dims, 4, batch_first=True) + self.layer_norm_block = nn.LayerNorm([dims]) + self.a_layer = nn.Sequential( + nn.LayerNorm([dims]), + nn.Linear(dims, dims), + nn.ReLU(), + nn.Linear(dims, dims), + ) + + def forward(self, x): + """ + Method to run an input tensor forward through the attention block + and returns the output tensor + + Args: + x (Tensor): input tensor + + Returns: + Tensor: output tensor + """ + # Restructure the tensor for cross attention + x = x.view(-1, self.dims, self.dim_size * self.dim_size).swapaxes(1, 2) + x1 = self.layer_norm_block(x) + x2, _ = self.mha_block(x1, x1, x1) + x3 = x2 + x + x4 = self.a_layer(x3) + x3 + # Return the restructured attention tensor + return x4.swapaxes(2, 1).view(-1, self.dims, self.dim_size, self.dim_size) + +class UNet(nn.Module): + """ + Unet model consisting of a decoding block, an encoding block, + cross attention, and residual skip connections + """ + def __init__(self, dim_in_out=3, m_dim=64, pos_dim=256): + """ + Unet class constructor to initialize the object + + Args: + dim_in_out (int, optional): number of channels in the input image. Defaults to 3. + m_dim (int, optional): dimensional multiplier for generalization. Defaults to 64. + pos_dim (int, optional): positional dimension. Defaults to 256. 
+ """ + super(UNet, self).__init__() + self.pos_dim = pos_dim + # Encoding part of the UNet # in --> out + self.in_layer = ConvReluBlock(dim_in_out, m_dim) # 3 --> 64 + self.encoder1 = EncoderBlock(m_dim, m_dim*2) # 64 --> 128 + self.attention1 = AttentionBlock(m_dim*2, int(m_dim/2)) # 128 --> 32 + self.encoder2 = EncoderBlock(m_dim*2, m_dim*4) # 128 --> 256 + self.attention2 = AttentionBlock(m_dim*4, int(m_dim/4)) # 256 --> 16 + self.encoder3 = EncoderBlock(m_dim*4, m_dim*4) # 256 --> 256 + self.attention3 = AttentionBlock(m_dim*4, int(m_dim/8)) # 256 --> 8 + + # Bottle neck of the UNet # in --> out + self.b1 = ConvReluBlock(m_dim*4, m_dim*8) # 256 --> 512 + self.b2 = ConvReluBlock(m_dim*8, m_dim*8) # 512 --> 512 + self.b3 = ConvReluBlock(m_dim*8, m_dim*4) # 512 --> 256 + + # Decoding part of the UNet # in --> out + self.decoder1 = DecoderBlock(m_dim*8, m_dim*2) # 512 --> 128 + self.attention4 = AttentionBlock(m_dim*2, int(m_dim/4)) # 128 --> 16 + self.decoder2 = DecoderBlock(m_dim*4, m_dim) # 256 --> 64 + self.attention5 = AttentionBlock(m_dim, int(m_dim/2)) # 64 --> 32 + self.decoder3 = DecoderBlock(m_dim*2, m_dim) # 128 --> 64 + self.attention6 = AttentionBlock(m_dim, m_dim) # 64 --> 64 + self.out_layer = nn.Conv2d(m_dim, dim_in_out, kernel_size=1) # 64 --> 3 + + def positional_embedding(self, position, dims): + """ + Calculate the positional tensor using transformer positional embedding + + Args: + position (Tensor): position tensor result of previous positional encoding + dims (int): number of channels + + Returns: + Tensor: positional embedded tensor + """ + inv_freq = 1.0 / ( + 10000 + ** (torch.arange(0, dims, 2, device="cuda").float() / dims) + ) + positional_embedding_a = torch.sin(position.repeat(1, dims // 2) * inv_freq) + positional_embedding_b = torch.cos(position.repeat(1, dims // 2) * inv_freq) + positional_embedding = torch.cat([positional_embedding_a, positional_embedding_b], dim=-1) + return positional_embedding + + def forward(self, x, position): + """ + Method to run an input tensor forward through the unet + and returns the output from all Unet layers + + Args: + x (Tensor): input tensor + position (Tensor): position tensor result of positional encoding + + Returns: + Tensor: output tensor + """ + position = position.unsqueeze(-1).type(torch.float) + position = self.positional_embedding(position, self.pos_dim) + + # Encoder forward step + x1 = self.in_layer(x) + x2 = self.encoder1(x1, position) + x2 = self.attention1(x2) + x3 = self.encoder2(x2, position) + x3 = self.attention2(x3) + x4 = self.encoder3(x3, position) + x4 = self.attention3(x4) + + # Bottle neck forward step + x4 = self.b1(x4) + x4 = self.b2(x4) + x4 = self.b3(x4) + + # Decoder forward step + x = self.decoder1(x4, x3, position) + x = self.attention4(x) + x = self.decoder2(x, x2, position) + x = self.attention5(x) + x = self.decoder3(x, x1, position) + x = self.attention6(x) + out = self.out_layer(x) + + return out \ No newline at end of file diff --git a/recognition/OASIS-Brain-StableDiffusion/predict.py b/recognition/OASIS-Brain-StableDiffusion/predict.py new file mode 100644 index 0000000000..3505ef5a72 --- /dev/null +++ b/recognition/OASIS-Brain-StableDiffusion/predict.py @@ -0,0 +1,107 @@ +import torch +import matplotlib.pyplot as plt +from datetime import datetime +from tqdm import tqdm +import numpy as np +from modules import * +from utils import * + + +def show_single_image(model): + """ + Create and show a single image from the model + + Args: + model (Module): PyTorch model to use + """ + 
model = model.to("cuda")
+
+    # Set image to random noise
+    image = torch.randn((1, 3, 64, 64)).to("cuda")
+
+    # Iteratively remove the noise
+    for i in tqdm(range(499, -1, -1)):
+        with torch.no_grad():
+            image = remove_noise(image, torch.tensor([i]).to("cuda"), model)
+
+    # convert image back into range [0, 255]
+    image = image[0].permute(1, 2, 0).detach().to("cpu")
+    # image = image*255
+    # image = np.array(image, dtype=np.uint8)
+
+    # Get unique time stamp
+    dt = datetime.now()
+    ts = round(datetime.timestamp(dt))
+
+    # Display using pyplot
+    plt.figure(figsize=(4,4))
+    plt.axis("off")
+    plt.title("Image Created from Stable Diffusion Model")
+    plt.imshow(image)
+    plt.savefig('Image{}.png'.format(ts))
+    plt.show()
+
+
+def show_x_images(model, img_num=6):
+    """
+    Create and show a given number of images from the model
+
+    Args:
+        model (Module): PyTorch model to use
+        img_num (int, optional): number of images to create. Defaults to 6.
+    """
+    model = model.to("cuda")
+
+    # Setup pyplot
+    plt.figure(figsize=(12, 3))
+    plt.axis("off")
+    plt.title("Images Created from Stable Diffusion Model")
+    img_pos = 1
+
+    for i in range(1, img_num+1, 1):
+        image = torch.randn((1, 3, 64, 64)).to("cuda")
+        for j in tqdm(range(499, -1, -1)):
+            with torch.no_grad():
+                image = remove_noise(image, torch.tensor([j]).to("cuda"), model)
+
+        # convert image back into range [0, 255]
+        image = image[0].permute(1, 2, 0).detach().to("cpu")
+        image = image*255
+        image = np.array(image, dtype=np.uint8)
+        sub = plt.subplot(1, img_num, img_pos)
+        sub.axis("off")
+        plt.imshow(image)
+        img_pos += 1
+
+    # Get unique time stamp
+    dt = datetime.now()
+    ts = round(datetime.timestamp(dt))
+    plt.savefig('Image{}.png'.format(ts))
+    plt.show()
+
+
+def load_model(model_path):
+    """
+    Load a PyTorch model from a saved model file
+
+    Args:
+        model_path (string): file path for the model
+
+    Returns:
+        Module: the loaded PyTorch model
+    """
+    model = UNet()
+    # load the saved weights into the freshly constructed UNet
+    model.load_state_dict(torch.load(model_path))
+    model.to("cuda")
+    model.eval()
+    return model
+
+def main():
+    model_path = r".\model.pt"
+    model = load_model(model_path)
+    show_x_images(model)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/recognition/OASIS-Brain-StableDiffusion/readme.md b/recognition/OASIS-Brain-StableDiffusion/readme.md
new file mode 100644
index 0000000000..4a3301c29f
--- /dev/null
+++ b/recognition/OASIS-Brain-StableDiffusion/readme.md
@@ -0,0 +1,443 @@
+# OASIS Brain MRI Stable Diffusion
+
+##### Jonathan Allen (s4640736)
+
+## Model Description
+Stable Diffusion is a deep learning image generation model. Unlike other diffusion models that denoise in pixel space, Stable Diffusion uses a Latent Diffusion Model (LDM) to add and remove noise in the latent space, making it more time and space efficient. An Encoder is used to transform images into the latent space, while a Decoder is used to transform the latent space back into pixel space. In the model, each denoising step is performed by a modified U-Net architecture consisting of cross attention and positional embedding. The cross attention increases accuracy by enhancing prominent regions of the image and suppressing irrelevant regions. The positional embedding tells the model which noising or denoising step it is on, which is vital for adding and removing the correct amount of noise, since the basic U-Net architecture has no notion of the timestep. 
For more information on the model used and for the mathematics behind it please read this [paper](https://arxiv.org/pdf/2112.10752.pdf). + +![image](https://miro.medium.com/max/1400/0*rW_y1kjruoT9BSO0.png) + +Source: https://arxiv.org/pdf/2112.10752.pdf + +For the model implemented here, the conditioning part of the Stable Diffusion Model was intentionally dismissed as unnecessary since the task is to only recreate one specific type of data set. The conditioning would be needed if we were trying to recreate many types of data from a given input. For example, creating an image from text, in this case, text is the conditioning element. + +## How Stable Diffusion Works +At a high level, diffusion models work by progressively altering the training data by adding noise, and then "learn" to recover the data by reversing this noising process. After training, this reversing of the noising process can be applied to random noise, hence generating new data. In other words, a complete diffusion model can produce coherent images from initial random noise. + +In this case specifically, the model is trained on OASIS brain MRI data, it adds a specific amount of noise to the image, and then iteratively learns how to remove the added noise in hopes to recover the altered training image. Once trained, it can be given a similar photo containing nothing but random noise, and then using the learned denoising process, generate a photo-realistic image resembling the training data. + +This image shows a visual representation of the model noising an image and then learning to denoise it by reversing the process. Noising of the image going left to right on the first row and denoising of the image going right to left on the second row. +![Example Noising Process](https://lh3.googleusercontent.com/432gw-wUaTSikRtRp2IjoIRxM_xLYhy06LXcUYfHmVZoGJfWGl88HX5DO4jUxxhaZdPY_yDsKymTyHqO3oNz5vVv71poNJAwkbaYXtStpA5XyjPTqjvA3NNJK5rJndkgru4f9DPfqdqwKQuazuND-yWpn0uplZ-6mUfboiLh1BNEu1a92Pxm83gDtYfhr7chxzZW1ibgPp6dJ8G75yWy26SxjA6n9hgSDpqQgQj-QmRZURf7zcXnGbPMvk_1Je-uB2nzxIfswWVyb7isxdBKU75NzyV-a6zNLdZY9CDEgU50jzrCYeAA8_mjWNFDHsG_kyQgsCbAcdt4Logvk-d-ipqi12LRE83XsfOWopI9-Bs9FDN0eDBndNTPWh_PsGzaw1ZyAn-tJSzmtRjz3DQnnQ3J34BvFiYkZyPSBErDLvAYemeIphUZ-u7qxlbgi9HmkOU_g4AtMEc637LuMhD8bQN8u9y2cA74giWEce_Xw8E62oR4oowKkKCWWLw6HFs_JoLAAb4NJ6eJs_2JDOvDcKVVyNt07_mWZdNx2xvB2bjEoKIf-s4iBMT0q0RcxqUfhZk8ItM9nRuEkrx1DuGc1BuDWLjsfSUIZ5UHRgRlO11G6-zHhmvPUyAYnguS3k6bs8rTrMmGf6Fu6zWIydvxEUtsfJ97ZsRbmDCe1pbq4dVF-PMLoeTAKQagh0iTd6gvlsHijNsq2erqU0tSiMyVlGOk8tsZs5hVlFDJXCxaMXQpi6Mbpb_ErI-azmB0-CUi8mAdOphz2AKSp_0dMTgyIn25Gc3JI8BFerIVYSee2zjMYPb9NxNskS77yNRyPNCMWKAu4Ogv4zQihrPltHwo0kvz82Fcz6_XjRBy3NOh6NvyRBNujKz24_90iKvrg8wxNo6l4v5Z93MXhv70ctW3d8QPR1zL_I145aBp1A=w642-h319-no?authuser=0) +## Project Structure +This project is built using Python3.10.6 and PyTorch1.12.1. +## Data Pipeline `dataset.py` +### Data Loading +The OASIS Brain MRI dataset is the data this model was trained against. +This data can be found [here](https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA). +Example MRI slice from the dataset + +![Brain MRI](https://lh3.googleusercontent.com/pw/AL9nZEViraVfAx4nNjNFk7ga3r2QBN5zKUvgXMg7C-OvQLNKJN_mnTjKSrS4PmHYn5VZlt0ZUenfr15Bym4h08bWUF6XhivR0WwOXxGN1IJM2C7_oxYpSskmnNR9tzFdSVWPNmuhdTFF24qV4DDC4qrnkUx2=s256-no?authuser=0) + +Source: [OASIS Brain MRI dataset](https://cloudstor.aarnet.edu.au/plus/s/tByzSZzvvVh0hZA) + +To load this data into the program, I used a function called `load_dataset` found in `dataset.py`. 
This function loads the images from a folder, normalizes them, and wraps them in a PyTorch DataLoader for ease of use.
+Two completely separate folders were used to store the training data and the test data, `training_data` and `test_data`. The `training_data` folder contains 8,406 images, while `test_data` contains 544 images. Separating the data into two folders ensured that the model never saw the test data until validation.
+Although the data looks grayscale, I did not convert it to grayscale for training, so that the model can also be trained on coloured images in the future.
+
+### Preprocessing/Normalization
+All the input data needs to be preprocessed and normalized to be useful for the model. This normalization is also done in the `load_dataset` function found in `dataset.py`. It is done using `torchvision` and consists of four transforms.
+The first transform resizes the image to be a quarter larger, to allow for random cropping:
+`torchvision.transforms.Resize(image_size+round(0.25*image_size))`
+The second randomly crops the image back down to its intended size, but slightly off centre:
+`torchvision.transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0))`
+The third converts it to a PyTorch tensor:
+`torchvision.transforms.ToTensor()`
+The fourth normalizes it using the given mean and standard deviation:
+`torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`
+This preprocessing and normalization helps the model generalize rather than simply memorize the training images.
+
+## Modules `modules.py`
+As discussed above, a modified U-Net is used to denoise the data in latent space. This modified U-Net has cross attention blocks along with positional embedding, and is described below.
+
+A general U-Net is an encoder-decoder convolutional network with skip connections between each pair of opposing blocks. It was originally created to produce image segmentation maps, which makes it a natural choice for the denoising network in a stable diffusion model.
+
+![Unet Diagram](https://miro.medium.com/max/1200/1*f7YOaE4TWubwaFF7Z1fzNw.png)
+
+Source: https://towardsdatascience.com/unet-line-by-line-explanation-9b191c76baf5
+
+The U-Net implemented for this model is slightly different and consists of the following blocks:
+- `ConvReluBlock` - a double convolution/ReLU block with group normalization
+- `EncoderBlock` - a max pooling layer followed by 2 ConvReluBlocks, combined with the embedded position tensor
+- `DecoderBlock` - an upsample layer followed by 2 ConvReluBlocks, combined with the embedded position tensor
+- `AttentionBlock` - a transformer attention block that enhances some parts of the data and diminishes other parts
+- `UNet` - the full model, consisting of an encoding path, a bottleneck, a decoding path, cross attention, and residual skip connections along with positional encoding
+
+These blocks can be found in the `modules.py` file.
+To better understand the model, annotated forward steps for each block are shown below.
+### ConvReluBlock
+```python
+# Block 1
+x1 = self.conv1(x)  #nn.Conv2d()
+x2 = self.gNorm(x1) #nn.GroupNorm()
+x3 = self.relu(x2)  #nn.ReLU()
+# Block 2
+x4 = self.conv2(x3) #nn.Conv2d()
+x5 = self.gNorm(x4) #nn.GroupNorm()
+# Handle Residuals
+if (self.residual_connection):
+    x6 = F.relu(x + x5)
+else:
+    x6 = x5
+return x6
+```
+### EncoderBlock
+```python
+x = self.pool(x)           #nn.MaxPool2d()
+x = self.encoder_block1(x) #ConvReluBlock()
+x = self.encoder_block2(x) #ConvReluBlock()
+emb_x = self.embedded_block(position)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) #nn.ReLU() followed by nn.Linear()
+return x + emb_x #positional embedding, emb_x, is added in at every step
+```
+
+### DecoderBlock
+```python
+x = self.upSample_block(x)             #nn.Upsample()
+x = torch.cat([skip_tensor, x], dim=1) #Adding in the skip connections from encoder
+x = self.decoder_block1(x)             #ConvReluBlock()
+x = self.decoder_block2(x)             #ConvReluBlock()
+emb_x = self.embedded_block(position)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) #nn.ReLU() followed by nn.Linear()
+return emb_x + x #positional embedding, emb_x, is added in at every step
+```
+### UNet
+```python
+position = position.unsqueeze(-1).type(torch.float)
+position = self.positional_embedding(position, self.pos_dim)
+# Encoder forward step                                # in --> out (Tensor Size)
+x1 = self.in_layer(x)               #ConvReluBlock()  # 3 --> 64
+x2 = self.encoder1(x1, position)    #EncoderBlock()   # 64 --> 128
+x2 = self.attention1(x2)            #AttentionBlock() # 128 --> 32
+x3 = self.encoder2(x2, position)    #EncoderBlock()   # 128 --> 256
+x3 = self.attention2(x3)            #AttentionBlock() # 256 --> 16
+x4 = self.encoder3(x3, position)    #EncoderBlock()   # 256 --> 256
+x4 = self.attention3(x4)            #AttentionBlock() # 256 --> 8
+
+# Bottle neck forward step                            # in --> out (Tensor Size)
+x4 = self.b1(x4)                    #ConvReluBlock()  # 256 --> 512
+x4 = self.b2(x4)                    #ConvReluBlock()  # 512 --> 512
+x4 = self.b3(x4)                    #ConvReluBlock()  # 512 --> 256
+
+# Decoder forward step                                # in --> out (Tensor Size)
+x = self.decoder1(x4, x3, position) #DecoderBlock()   # 512 --> 128
+x = self.attention4(x)              #AttentionBlock() # 128 --> 16
+x = self.decoder2(x, x2, position)  #DecoderBlock()   # 256 --> 64
+x = self.attention5(x)              #AttentionBlock() # 64 --> 32
+x = self.decoder3(x, x1, position)  #DecoderBlock()   # 128 --> 64
+x = self.attention6(x)              #AttentionBlock() # 64 --> 64
+out = self.out_layer(x)             #nn.Conv2d()      # 64 --> 3
+
+return out
+```
+
+## Training `train.py`
+The training loop for this network is a standard PyTorch training loop consisting of the main model, an optimizer, and a loss function.
+### Main Model
+The main model for this training loop is the U-Net outlined above in the modules section.
+### Optimizer
+The optimizer chosen for this training loop is the same one used in the paper, because it generalizes well. This optimizer is Adam, from `torch.optim`.
+### Loss Function
+The loss function chosen for this training loop is Mean Squared Error (squared L2 norm), `torch.nn.MSELoss`. It was chosen because the network is highly connected and its targets are noisy, and MSE provides some generalization and smoothing. The loss compares the noise that was actually added to an image with the noise the U-Net predicts, as sketched below.
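+
+Condensed to a single optimization step, the objective described above reduces to the following sketch. This is an illustration only, not code from this repository; it assumes the `add_noise` and `get_sample_pos` helpers from `utils.py`, a `UNet` instance named `model` already on the GPU, and a CUDA device.
+
+```python
+import torch
+import torch.nn as nn
+
+loss_fn = nn.MSELoss()
+data = torch.randn(4, 3, 64, 64, device="cuda")      # one dummy batch of images
+position = get_sample_pos(data.shape[0]).to("cuda")  # random diffusion timesteps
+noisy_x, noise = add_noise(data, position)           # forward (noising) process
+predicted_noise = model(noisy_x, position)           # U-Net predicts the added noise
+loss = loss_fn(noise, predicted_noise)               # MSE between true and predicted noise
+```
+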
+### Training Loop
+```python
+for epoch in range(epochs):
+    epoch_loss = 0
+    for idx, (data, _) in enumerate(tqdm(train_dataloader)):
+        data = data.to(device)
+        position = get_sample_pos(data.shape[0]).to(device)
+        noisy_x, noise = add_noise(data, position)
+        predicted_noise = model(noisy_x, position)
+        loss = loss_fn(noise, predicted_noise)
+        optimizer.zero_grad()
+        loss.backward()
+        epoch_loss += loss.item()
+        optimizer.step()
+    tracked_loss.append([epoch_loss / dataloader_length])
+    print("Current Loss ==> {}".format(epoch_loss/dataloader_length))
+    test_loss.append(test_model(model, test_path, batch_size, device))
+```
+The last line of the training loop calls the validation loop for that epoch, gathering the loss on the test data to ensure the model is not over-fitting. The losses are then saved to a CSV file and analyzed with Microsoft Excel.
+
+### Validation Loop
+```python
+for idx, (data, _) in enumerate(tqdm(test_dataloader)):
+    data = data.to(device)
+    position = get_sample_pos(data.shape[0]).to(device)
+    noisy_x, noise = add_noise(data, position)
+    predicted_noise = model(noisy_x, position)
+    loss = loss_fn(noise, predicted_noise)
+    running_loss += loss.item()
+```
+### Training Results
+#### Graph showing the running loss of the training set in the training loop
+![Training Loss Vs Epochs](https://lh3.googleusercontent.com/pw/AL9nZEXH2I2U1lkr2GSYLzaDYCpROtEi_1OBWQLEEBIUu50t-2Rl5OBAeSYB05HEHLiOlItM_UJbGPEldyLEzeI_46pKSp8fuvKqYB1iA5NfXHwDZUqOyJlrYPMAtXYspKBMeeKLyjV9KHgCMXu5Rpgl3aCJ=w752-h452-no?authuser=0)
+#### Graph showing the running loss of the test set in the validation loop
+![Testing Loss vs Epoch](https://lh3.googleusercontent.com/pw/AL9nZEUevH6b6bYM2b4t073QSLS3iTE9O2KyasB9qhwcNqdSRcER5fsRassBdCob0oDd1uuZ7WHMSpzEigIQY1Jd_HyiAT6pnKFMu_tLvZwFHt_XvkD1ZTRspbIA4_cU-ci_1FW0_52kIis50unYYOygVD8X=w751-h452-no?authuser=0)
+Only the first 30 epochs are shown, as the loss flattens off and plateaus for the remaining epochs.
+### Training Your Own Model
+To train your own stable diffusion model, ensure the hyperparameters in `main()` in `train.py` meet your specification (the defaults are shown below).
+```python
+def main():
+    #hyperparameters
+    device = "cuda"
+    lr = 3e-4
+    train_path = r".\OASIS-Brain-Data\training_data"
+    test_path = r".\OASIS-Brain-Data\test_data"
+    model = UNet().to(device)
+    batch_size = 12
+    epochs = 200
+```
+Then run `train.py` in the terminal with the command `python train.py`.
+
+## Results `predict.py`
+Once a model is trained, `predict.py` can be used to load that model using `load_model()`.
+After the model is loaded, it can be used to generate images from noise using the functions `show_single_image()` and `show_x_images()`. The path to that model must be specified in the `main` method. By default `predict.py` generates 6 images in a row and saves the figure.
+
+### Images Generated From Stable Diffusion Model
+Here are six brain MRIs generated from this stable diffusion model.
+
+
+
+
+
+To generate your own images, just run `python predict.py` in the terminal. Remember that a model has to be trained, saved, and then its path has to be referenced appropriately in `predict.py`; a minimal calling sketch is shown below.
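+
+If you prefer to call the functions directly rather than running the script, a minimal sketch is shown here. It assumes a checkpoint saved at `.\model.pt` (the path used in `predict.py`'s `main()`; substitute your own).
+
+```python
+from predict import load_model, show_single_image
+
+model = load_model(r".\model.pt")  # builds the UNet and loads the saved weights
+show_single_image(model)           # denoises random noise for 500 steps and saves the figure
+```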
+ +## References + +- [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) +- [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233) +- [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) +- [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) +- [The Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion) + + +## Dependencies +``` +# Name Version Build Channel +absl-py 1.2.0 pyhd8ed1ab_0 conda-forge +aiohttp 3.8.1 py310he2412df_1 conda-forge +aiosignal 1.2.0 pyhd8ed1ab_0 conda-forge +argon2-cffi 21.3.0 pyhd8ed1ab_0 conda-forge +argon2-cffi-bindings 21.2.0 py310he2412df_2 conda-forge +asttokens 2.0.8 pyhd8ed1ab_0 conda-forge +async-timeout 4.0.2 pyhd8ed1ab_0 conda-forge +attrs 22.1.0 pyh71513ae_1 conda-forge +backcall 0.2.0 pyh9f0ad1d_0 conda-forge +backports 1.0 py_2 conda-forge +backports.functools_lru_cache 1.6.4 pyhd8ed1ab_0 conda-forge +beautifulsoup4 4.11.1 pyha770c72_0 conda-forge +blas 2.116 mkl conda-forge +blas-devel 3.9.0 16_win64_mkl conda-forge +bleach 5.0.1 pyhd8ed1ab_0 conda-forge +blinker 1.4 py_1 conda-forge +brotli 1.0.9 h8ffe710_7 conda-forge +brotli-bin 1.0.9 h8ffe710_7 conda-forge +brotlipy 0.7.0 py310he2412df_1004 conda-forge +bzip2 1.0.8 h8ffe710_4 conda-forge +c-ares 1.18.1 h8ffe710_0 conda-forge +ca-certificates 2022.9.24 h5b45459_0 conda-forge +cachetools 5.2.0 pyhd8ed1ab_0 conda-forge +certifi 2022.9.24 pyhd8ed1ab_0 conda-forge +cffi 1.15.1 py310hcbf9ad4_0 conda-forge +charset-normalizer 2.1.1 pyhd8ed1ab_0 conda-forge +click 8.1.3 py310h5588dad_0 conda-forge +colorama 0.4.5 pyhd8ed1ab_0 conda-forge +cryptography 37.0.1 py310h21b164f_0 +cudatoolkit 11.6.0 hc0ea762_10 conda-forge +cycler 0.11.0 pyhd8ed1ab_0 conda-forge +debugpy 1.6.3 py310h8a704f9_0 conda-forge +decorator 5.1.1 pyhd8ed1ab_0 conda-forge +defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge +entrypoints 0.4 pyhd8ed1ab_0 conda-forge +executing 1.0.0 pyhd8ed1ab_0 conda-forge +flit-core 3.7.1 pyhd8ed1ab_0 conda-forge +fonttools 4.37.1 py310he2412df_0 conda-forge +freetype 2.12.1 h546665d_0 conda-forge +frozenlist 1.3.1 py310he2412df_0 conda-forge +gettext 0.19.8.1 ha2e2712_1008 conda-forge +glib 2.72.1 h7755175_0 conda-forge +glib-tools 2.72.1 h7755175_0 conda-forge +google-auth 2.11.0 pyh6c4a22f_0 conda-forge +google-auth-oauthlib 0.4.6 pyhd8ed1ab_0 conda-forge +grpc-cpp 1.48.1 h535cfc9_1 conda-forge +grpcio 1.48.1 py310hd8b4215_1 conda-forge +gst-plugins-base 1.20.3 h001b923_1 conda-forge +gstreamer 1.20.3 h6b5321d_1 conda-forge +icu 70.1 h0e60522_0 conda-forge +idna 3.3 pyhd8ed1ab_0 conda-forge +importlib-metadata 4.11.4 py310h5588dad_0 conda-forge +importlib_resources 5.9.0 pyhd8ed1ab_0 conda-forge +intel-openmp 2022.1.0 h57928b3_3787 conda-forge +ipykernel 6.15.2 pyh025b116_0 conda-forge +ipython 8.5.0 pyh08f2357_1 conda-forge +ipython_genutils 0.2.0 py_1 conda-forge +jedi 0.18.1 pyhd8ed1ab_2 conda-forge +jinja2 3.1.2 pyhd8ed1ab_1 conda-forge +joblib 1.1.0 pyhd8ed1ab_0 conda-forge +jpeg 9e h8ffe710_2 conda-forge +jsonschema 4.15.0 pyhd8ed1ab_0 conda-forge +jupyter_client 7.3.5 pyhd8ed1ab_0 conda-forge +jupyter_contrib_core 0.4.0 pyhd8ed1ab_0 conda-forge +jupyter_contrib_nbextensions 0.5.1 pyhd8ed1ab_2 conda-forge +jupyter_core 4.11.1 py310h5588dad_0 conda-forge +jupyter_highlight_selected_word 0.2.0 py310h5588dad_1005 conda-forge +jupyter_latex_envs 1.4.6 pyhd8ed1ab_1002 conda-forge +jupyter_nbextensions_configurator 0.4.1 pyhd8ed1ab_2 conda-forge 
+jupyterlab_pygments 0.2.2 pyhd8ed1ab_0 conda-forge +kiwisolver 1.4.4 py310h476a331_0 conda-forge +krb5 1.19.3 h1176d77_0 conda-forge +lcms2 2.12 h2a16943_0 conda-forge +lerc 4.0.0 h63175ca_0 conda-forge +libabseil 20220623.0 cxx17_h1a56200_4 conda-forge +libblas 3.9.0 16_win64_mkl conda-forge +libbrotlicommon 1.0.9 h8ffe710_7 conda-forge +libbrotlidec 1.0.9 h8ffe710_7 conda-forge +libbrotlienc 1.0.9 h8ffe710_7 conda-forge +libcblas 3.9.0 16_win64_mkl conda-forge +libclang 14.0.6 default_h77d9078_0 conda-forge +libclang13 14.0.6 default_h77d9078_0 conda-forge +libdeflate 1.13 h8ffe710_0 conda-forge +libffi 3.4.2 h8ffe710_5 conda-forge +libglib 2.72.1 h3be07f2_0 conda-forge +libiconv 1.16 he774522_0 conda-forge +liblapack 3.9.0 16_win64_mkl conda-forge +liblapacke 3.9.0 16_win64_mkl conda-forge +libogg 1.3.4 h8ffe710_1 conda-forge +libpng 1.6.37 h1d00b33_4 conda-forge +libprotobuf 3.21.5 h12be248_3 conda-forge +libsodium 1.0.18 h8d14728_1 conda-forge +libsqlite 3.39.3 hcfcfb64_0 conda-forge +libtiff 4.4.0 h92677e6_3 conda-forge +libuv 1.44.2 h8ffe710_0 conda-forge +libvorbis 1.3.7 h0e60522_0 conda-forge +libwebp-base 1.2.4 h8ffe710_0 conda-forge +libxcb 1.13 hcd874cb_1004 conda-forge +libxml2 2.9.14 hf5bbc77_4 conda-forge +libxslt 1.1.35 h34f844d_0 conda-forge +libzlib 1.2.12 h8ffe710_2 conda-forge +lxml 4.9.1 py310he2412df_0 conda-forge +m2w64-gcc-libgfortran 5.3.0 6 conda-forge +m2w64-gcc-libs 5.3.0 7 conda-forge +m2w64-gcc-libs-core 5.3.0 7 conda-forge +m2w64-gmp 6.1.0 2 conda-forge +m2w64-libwinpthread-git 5.0.0.4634.697f757 2 conda-forge +markdown 3.4.1 pyhd8ed1ab_0 conda-forge +markupsafe 2.1.1 py310he2412df_1 conda-forge +matplotlib 3.5.3 py310h5588dad_2 conda-forge +matplotlib-base 3.5.3 py310h7329aa0_2 conda-forge +matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge +mistune 2.0.4 pyhd8ed1ab_0 conda-forge +mkl 2022.1.0 h6a75c08_874 conda-forge +mkl-devel 2022.1.0 h57928b3_875 conda-forge +mkl-include 2022.1.0 h6a75c08_874 conda-forge +msys2-conda-epoch 20160418 1 conda-forge +multidict 6.0.2 py310he2412df_1 conda-forge +munkres 1.1.4 pyh9f0ad1d_0 conda-forge +nb_conda_kernels 2.3.1 py310h5588dad_1 conda-forge +nbclient 0.6.7 pyhd8ed1ab_0 conda-forge +nbconvert 7.0.0 pyhd8ed1ab_0 conda-forge +nbconvert-core 7.0.0 pyhd8ed1ab_0 conda-forge +nbconvert-pandoc 7.0.0 pyhd8ed1ab_0 conda-forge +nbformat 5.4.0 pyhd8ed1ab_0 conda-forge +nest-asyncio 1.5.5 pyhd8ed1ab_0 conda-forge +notebook 6.4.12 pyha770c72_0 conda-forge +numpy 1.23.2 py310h8a5b91a_0 conda-forge +oauthlib 3.2.1 pyhd8ed1ab_0 conda-forge +openjpeg 2.5.0 hc9384bd_1 conda-forge +openssl 1.1.1q h8ffe710_0 conda-forge +packaging 21.3 pyhd8ed1ab_0 conda-forge +pandas 1.4.4 pypi_0 pypi +pandoc 2.19.2 h57928b3_0 conda-forge +pandocfilters 1.5.0 pyhd8ed1ab_0 conda-forge +parso 0.8.3 pyhd8ed1ab_0 conda-forge +pathlib 1.0.1 py310h5588dad_6 conda-forge +pcre 8.45 h0e60522_0 conda-forge +pickleshare 0.7.5 py_1003 conda-forge +pillow 9.2.0 py310h52929f7_2 conda-forge +pip 22.2.2 pyhd8ed1ab_0 conda-forge +pkgutil-resolve-name 1.3.10 pyhd8ed1ab_0 conda-forge +ply 3.11 py_1 conda-forge +prometheus_client 0.14.1 pyhd8ed1ab_0 conda-forge +prompt-toolkit 3.0.31 pyha770c72_0 conda-forge +protobuf 3.19.5 pypi_0 pypi +psutil 5.9.2 py310h8d17308_0 conda-forge +pthread-stubs 0.4 hcd874cb_1001 conda-forge +pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge +pyasn1 0.4.8 py_0 conda-forge +pyasn1-modules 0.2.8 pypi_0 pypi +pycparser 2.21 pyhd8ed1ab_0 conda-forge +pygments 2.13.0 pyhd8ed1ab_0 conda-forge +pyjwt 2.4.0 pyhd8ed1ab_0 conda-forge +pyopenssl 22.0.0 
pyhd8ed1ab_0 conda-forge +pyparsing 3.0.9 pyhd8ed1ab_0 conda-forge +pyqt 5.15.7 py310hbabf5d4_0 conda-forge +pyqt5-sip 12.11.0 py310h8a704f9_0 conda-forge +pyrsistent 0.18.1 py310he2412df_1 conda-forge +pysocks 1.7.1 pyh0701188_6 conda-forge +python 3.10.6 h9a09f29_0_cpython conda-forge +python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge +python-fastjsonschema 2.16.1 pyhd8ed1ab_0 conda-forge +python_abi 3.10 2_cp310 conda-forge +pytorch 1.12.1 py3.10_cuda11.6_cudnn8_0 pytorch +pytorch-model-summary 0.1.1 py_0 conda-forge +pytorch-mutex 1.0 cuda pytorch +pytz 2022.2.1 pypi_0 pypi +pyu2f 0.1.5 pyhd8ed1ab_0 conda-forge +pywin32 303 py310he2412df_0 conda-forge +pywinpty 2.0.7 py310h00ffb61_0 conda-forge +pyyaml 6.0 py310he2412df_4 conda-forge +pyzmq 23.2.1 py310h73ada01_0 conda-forge +qt-main 5.15.4 h467ea89_2 conda-forge +re2 2022.06.01 h0e60522_0 conda-forge +requests 2.28.1 pyhd8ed1ab_1 conda-forge +requests-oauthlib 1.3.1 pyhd8ed1ab_0 conda-forge +rsa 4.9 pyhd8ed1ab_0 conda-forge +scikit-learn 1.1.2 py310h3a564e9_0 conda-forge +scipy 1.9.1 py310h578b7cb_0 conda-forge +send2trash 1.8.0 pyhd8ed1ab_0 conda-forge +setuptools 65.3.0 pyhd8ed1ab_1 conda-forge +sip 6.6.2 py310h8a704f9_0 conda-forge +six 1.16.0 pyh6c4a22f_0 conda-forge +soupsieve 2.3.2.post1 pyhd8ed1ab_0 conda-forge +sqlite 3.39.3 hcfcfb64_0 conda-forge +stack_data 0.5.0 pyhd8ed1ab_0 conda-forge +tbb 2021.5.0 h91493d7_2 conda-forge +tensorboard 2.10.1 pyhd8ed1ab_0 conda-forge +tensorboard-data-server 0.6.1 pypi_0 pypi +tensorboard-plugin-wit 1.8.1 pyhd8ed1ab_0 conda-forge +terminado 0.15.0 py310h5588dad_0 conda-forge +threadpoolctl 3.1.0 pyh8a188c0_0 conda-forge +tinycss2 1.1.1 pyhd8ed1ab_0 conda-forge +tk 8.6.12 h8ffe710_0 conda-forge +toml 0.10.2 pyhd8ed1ab_0 conda-forge +torch-summary 1.4.5 pypi_0 pypi +torch-tb-profiler 0.4.0 pypi_0 pypi +torchaudio 0.12.1 py310_cu116 pytorch +torchvision 0.13.1 py310_cu116 pytorch +tornado 6.2 py310he2412df_0 conda-forge +tqdm 4.64.1 pyhd8ed1ab_0 conda-forge +traitlets 5.3.0 pyhd8ed1ab_0 conda-forge +typing-extensions 4.3.0 hd8ed1ab_0 conda-forge +typing_extensions 4.3.0 pyha770c72_0 conda-forge +tzdata 2022c h191b570_0 conda-forge +ucrt 10.0.20348.0 h57928b3_0 conda-forge +unicodedata2 14.0.0 py310he2412df_1 conda-forge +urllib3 1.26.11 pyhd8ed1ab_0 conda-forge +vc 14.2 hb210afc_7 conda-forge +vs2015_runtime 14.29.30139 h890b9b1_7 conda-forge +wcwidth 0.2.5 pyh9f0ad1d_2 conda-forge +webencodings 0.5.1 py_1 conda-forge +werkzeug 2.2.2 pyhd8ed1ab_0 conda-forge +wheel 0.37.1 pyhd8ed1ab_0 conda-forge +win_inet_pton 1.1.0 py310h5588dad_4 conda-forge +winpty 0.4.3 4 conda-forge +xorg-libxau 1.0.9 hcd874cb_0 conda-forge +xorg-libxdmcp 1.1.3 hcd874cb_0 conda-forge +xz 5.2.6 h8d14728_0 conda-forge +yaml 0.2.5 h8ffe710_2 conda-forge +yarl 1.7.2 py310he2412df_2 conda-forge +zeromq 4.3.4 h0e60522_1 conda-forge +zipp 3.8.1 pyhd8ed1ab_0 conda-forge +zlib 1.2.12 h8ffe710_2 conda-forge +zstd 1.5.2 h7755175_4 conda-forge +``` diff --git a/recognition/OASIS-Brain-StableDiffusion/train.py b/recognition/OASIS-Brain-StableDiffusion/train.py new file mode 100644 index 0000000000..8a9e4fc7d0 --- /dev/null +++ b/recognition/OASIS-Brain-StableDiffusion/train.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from tqdm.auto import tqdm +from torch.optim import Adam +from modules import * +from dataset import * +from utils import * + +def train(device, lr, train_path, test_path, model, epochs, batch_size): + """ + Training loop to train the model + + Args: + device (string): device to train on + lr (float): learning 
rate of the model + train_path (string): path to training data + test_path (string): path to test data + model (Module): network model + epochs (int): number of epochs to run + batch_size (int): batch size + """ + train_dataloader = load_dataset(train_path, batch_size=batch_size) + dataloader_length = len(train_dataloader) + + model = model.to(device) + optimizer = Adam(model.parameters(), lr=lr) + loss_fn = nn.MSELoss() + tracked_loss = [] + test_loss = [] + + for epoch in range(epochs): + epoch_loss = 0 + for idx, (data, _) in enumerate(tqdm(train_dataloader)): + data = data.to(device) + position = get_sample_pos(data.shape[0]).to(device) + noisy_x, noise = add_noise(data, position) + predicted_noise = model(noisy_x, position) + loss = loss_fn(noise, predicted_noise) + + optimizer.zero_grad() + loss.backward() + epoch_loss += loss.item() + optimizer.step() + + tracked_loss.append([epoch_loss / dataloader_length]) + print("Current Loss ==> {}".format(epoch_loss/dataloader_length)) + test_loss.append(test_model(model, test_path, batch_size, device)) + + torch.save(model.state_dict(), "model") + save_loss_data(tracked_loss, test_loss) + +def test_model(model, test_path, batch_size, device): + """ + Test loop to test the model against test data + + Args: + model (Module): model to test + test_path (string): path to the test data + batch_size (int): batch size + device (string): device to use + + Returns: + List: index of loss of the model against the test data + """ + test_dataloader = load_dataset(test_path, batch_size=batch_size) + dataloader_length = len(test_dataloader) + + model = model.to(device) + + loss_fn = nn.MSELoss() + running_loss = 0 + for idx, (data, _) in enumerate(tqdm(test_dataloader)): + data = data.to(device) + position = get_sample_pos(data.shape[0]).to(device) + noisy_x, noise = add_noise(data, position) + predicted_noise = model(noisy_x, position) + loss = loss_fn(noise, predicted_noise) + running_loss += loss.item() + + print("Test Loss ==> {}".format(running_loss / dataloader_length)) + return [running_loss / dataloader_length] + +def main(): + #hyperparameters + device = "cuda" + lr = 3e-4 + train_path = r".\OASIS-Brain-Data\training_data" + test_path = r".\OASIS-Brain-Data\test_data" + model = UNet().to(device) + batch_size = 12 + epochs = 200 + + # start training + train(device, lr, train_path, test_path, model, epochs, batch_size) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/recognition/OASIS-Brain-StableDiffusion/utils.py b/recognition/OASIS-Brain-StableDiffusion/utils.py new file mode 100644 index 0000000000..e39ef9fd4c --- /dev/null +++ b/recognition/OASIS-Brain-StableDiffusion/utils.py @@ -0,0 +1,101 @@ +from time import time +import torch +import torch.nn.functional as F +import csv + + +def get_noise_cadence(): + """ + Generates a tensor describing the cadence of noise addition timestamps + + Returns: + Tensor: noise to be added + """ + return torch.linspace(1e-4, 0.02, 1000) + +def add_noise(x, pos): + """ + Adds noise to tensor x through a noising timeline + + Args: + x (Tensor): Tensor representation of an image to noise + pos (int): Position in the noising process to sample beta + + Returns: + Tensor: input tensor with noise added to it + """ + beta = get_noise_cadence().to("cuda") + alpha = 1.0 - beta + alpha_cumprod = torch.cumprod(alpha, dim=0) + sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod[pos])[:, None, None, None] + sqrt_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod[pos])[:, None, None, None] + E = 
torch.randn_like(x) + return sqrt_alpha_cumprod * x + sqrt_minus_alpha_cumprod * E, E + +def remove_noise(img, timestep, model): + beta = get_noise_cadence().to("cuda") + alpha = 1.0 - beta + alpha_cumprod = torch.cumprod(alpha, dim=0) + alpha_cumprod_rev = F.pad(alpha_cumprod[:-1], (1, 0), value=1.0) + sqrt_alpha_reciprocal = torch.sqrt(1.0 / alpha) + sqrt_minus_alpha_cumprod = torch.sqrt(1.0 - alpha_cumprod) + sqrt_minus_alpha_cumprod_x = extract_index(sqrt_minus_alpha_cumprod, timestep, img.shape) + sqrt_alpha_reciprocal_x = extract_index(sqrt_alpha_reciprocal, timestep, img.shape) + + mean = sqrt_alpha_reciprocal_x * (img - extract_index(beta, timestep, img.shape) * model(img, timestep) / sqrt_minus_alpha_cumprod_x) + + + if timestep == 0: + return mean + else: + E = torch.randn_like(img) + posterior_variance = beta * (1. - alpha_cumprod_rev) / (1.0 - alpha_cumprod) + + return mean + torch.sqrt(extract_index(posterior_variance, timestep, img.shape)) * E + +def get_sample_pos(size): + """ + Generates sampling tensor from input size and predefined sample range + + Args: + size (int): size to time step + + Returns: + Tensor: tensor full of random integers between 1 and 1000 with specified size + """ + return torch.randint(low=1, high=1000, size=(size,)) + +def extract_index(x, pos, x_shape): + """ + Returns a specific index, pos, in a tensor, x + + Args: + x (Tensor): input tensor + pos (Tensor): position tensor + x_shape (Size): shape of tensor + + Returns: + Tensor: index tensor + """ + batch_size = pos.shape[0] + output = x.gather(-1, pos.to("cuda")) + output = output.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to("cuda") + return output + +def save_loss_data(tracked_loss, test_loss): + """ + Save training and testing loss data to csv file + + Args: + tracked_loss (List): list of training loss values + test_loss (List): list of testing loss values + """ + # Save loss values + train_loss_file = open("Epoch Loss.csv", 'w') + writer = csv.writer(train_loss_file) + writer.writerows(tracked_loss) + + # Save test values + test_loss_file = open("Test Loss.csv", 'w') + writer = csv.writer(test_loss_file) + writer.writerows(test_loss) diff --git a/recognition/XUE4645768/Readme.md b/recognition/XUE4645768/Readme.md index 94bc1848c0..36250adaa3 100644 --- a/recognition/XUE4645768/Readme.md +++ b/recognition/XUE4645768/Readme.md @@ -53,52 +53,6 @@ python gcn.py Warning: Please pay attention to whether the data path is correct when you run the gcn.py. 
-# Training - -Learning rate= 0.01 -Weight dacay =0.005 - -For 200 epoches: -```Epoch 000: Loss 0.2894, TrainAcc 0.9126, ValAcc 0.8954 -Epoch 001: Loss 0.2880, TrainAcc 0.9126, ValAcc 0.895 -Epoch 002: Loss 0.2866, TrainAcc 0.9126, ValAcc 0.8961 -Epoch 003: Loss 0.2853, TrainAcc 0.9132, ValAcc 0.8961 -Epoch 004: Loss 0.2839, TrainAcc 0.9137, ValAcc 0.8961 -Epoch 005: Loss 0.2826, TrainAcc 0.9141, ValAcc 0.8963 -Epoch 006: Loss 0.2813, TrainAcc 0.9146, ValAcc 0.8956 -Epoch 007: Loss 0.2800, TrainAcc 0.9146, ValAcc 0.8956 -Epoch 008: Loss 0.2788, TrainAcc 0.9146, ValAcc 0.8959 -Epoch 009: Loss 0.2775, TrainAcc 0.9146, ValAcc 0.8970 -Epoch 010: Loss 0.2763, TrainAcc 0.915, ValAcc 0.8974 -Epoch 011: Loss 0.2751, TrainAcc 0.915, ValAcc 0.8972 -Epoch 012: Loss 0.2739, TrainAcc 0.915, ValAcc 0.8976 -Epoch 013: Loss 0.2727, TrainAcc 0.9157, ValAcc 0.8979 -Epoch 014: Loss 0.2716, TrainAcc 0.9157, ValAcc 0.8983 -Epoch 015: Loss 0.2704, TrainAcc 0.9161, ValAcc 0.8990 -Epoch 016: Loss 0.2693, TrainAcc 0.9168, ValAcc 0.8988 -Epoch 017: Loss 0.2682, TrainAcc 0.9181, ValAcc 0.8990 -Epoch 018: Loss 0.2671, TrainAcc 0.9179, ValAcc 0.8990 -Epoch 019: Loss 0.2660, TrainAcc 0.9179, ValAcc 0.8992 -Epoch 020: Loss 0.2650, TrainAcc 0.9188, ValAcc 0.8996 -...... -Epoch 190: Loss 0.1623, TrainAcc 0.9553, ValAcc 0.9134 -Epoch 191: Loss 0.1619, TrainAcc 0.9555, ValAcc 0.9134 -Epoch 192: Loss 0.1615, TrainAcc 0.9555, ValAcc 0.9132 -Epoch 193: Loss 0.1611, TrainAcc 0.9557, ValAcc 0.9130 -Epoch 194: Loss 0.1607, TrainAcc 0.9562, ValAcc 0.9130 -Epoch 195: Loss 0.1603, TrainAcc 0.9559, ValAcc 0.9130 -Epoch 196: Loss 0.1599, TrainAcc 0.9562, ValAcc 0.9126 -Epoch 197: Loss 0.1595, TrainAcc 0.9562, ValAcc 0.9123 -Epoch 198: Loss 0.1591, TrainAcc 0.9562, ValAcc 0.9123 -Epoch 199: Loss 0.1587, TrainAcc 0.9562, ValAcc 0.9123``` - -For test accuracy:around 0.9 - -# TSNE -For the test:iteration=500, with lower dimension to 2 - - - ```python