83 changes: 83 additions & 0 deletions recognition/Yeheng_Sun_S4548085_COMP3710_ADNIClassifier/README.md
@@ -0,0 +1,83 @@
# Yeheng_Sun_S4548085_COMP3710_ADNIClassifier

## This code provides a solution to problem 6. A detailed description is given below.

# Problem Definition

Classify Alzheimer’s disease (normal and AD) of the ADNI brain dataset using a visual transformer.

# Algorithm Description (Concept)

The model consists of three main modules. The first is the patch-encoding module: a patch layer (`Patches`) that splits each image into patches, followed by a patch-encoder layer (`Patch2Vec`) that encodes each patch into a vector. The second is the transformer module, which measures the relationships between pairs of input vectors. Finally, a multi-layer perceptron with two layers of 1024 and 512 neurons acts as the classifier. Combined, these modules make up the Vision Transformer model.
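
As a quick shape walkthrough of the patch pipeline, the sketch below assumes `patch_len = 48`, a value consistent with the patch count reported under "Changes compared to original ViT model", not one confirmed by train.py:

```python
# Shapes flowing through the patch pipeline; patch_len = 48 is an assumption.
img_size, patch_len, channels = 256, 48, 3
patch_n = (img_size // patch_len) ** 2        # (256 // 48) ** 2 = 25 patches
patch_dim = patch_len * patch_len * channels  # 48 * 48 * 3 = 6912 values per patch
# Patch2Vec then projects each flattened patch to a proj_vec_n-dimensional vector
# and adds a learned position embedding, giving a (patch_n, proj_vec_n) sequence
# that the transformer module consumes.
print(patch_n, patch_dim)                     # 25 6912
```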

# Training Procedure

The original images are 256 x 240 pixels. I resize them to 256 x 256 pixels, so the input shape of the model is (256, 256, 3), where 3 stands for the RGB channels. The metric is Binary Accuracy and the loss function is Binary Cross-entropy. The multi-head attention layer uses 4 heads. I use the Adam optimizer with a learning rate of 0.0003, together with the ReduceLROnPlateau callback, so that the learning rate is reduced when the loss stops improving.

I trained the model for 100 epochs.
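
Since train.py itself is not reproduced in this README, below is a minimal sketch of the training setup described above; the batch size, the architecture hyperparameters in `vit_config`, and the ReduceLROnPlateau `factor`/`patience` are assumptions, while the optimizer, learning rate, loss, metric, head count, and epoch count follow the text.

```python
from tensorflow import keras
import dataset
import module

# Assumed architecture hyperparameters (train.py holds the authoritative values).
vit_config = dict(
    input_shape=(256, 256, 3), patch_len=48, patch_n=(256 // 48) ** 2,
    proj_vec_n=64, transformer_n=8, head_n=4, class_n=1,
    transformer_units=[128, 64], mlp_head_units=[1024, 512],
)

train_ds, val_ds = dataset.createTrainData(img_size=256, batch_size=32)
model = module.vision_transformer(**vit_config)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=3e-4),  # 0.0003, as stated above
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy()],
)

# Reduce the learning rate when the validation loss stops improving
# (factor and patience are assumptions).
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5)

model.fit(train_ds, validation_data=val_ds, epochs=100, callbacks=[reduce_lr])
```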

The dataset had already been divided into training and test sets; a validation set is carved out of the training data (see Preprocessing below). The validation set is used during training to monitor for overfitting, while the test set assesses the model's generalisation on data never seen during training.


# Structure of the project

**README.md** is the file you are reading; it provides a detailed description of the dataset and scripts.

**ADNI_AD_NC_2D** is a directory containing the ADNI brain image dataset; its structure and the corresponding descriptions are shown below:

```
ADNI_AD_NC_2D # a directory containing the ADNI brain image dataset
└─AD_NC # a sub-directory
├─test # a directory containing the test set
│ ├─AD # a directory containing all the brain images with Alzheimer’s disease in the test set
│ └─NC # a directory containing all cognitively normal brain images in the test set
└─train # a directory containing the train set
├─AD # a directory containing all the brain images with Alzheimer’s disease in the train set
└─NC # a directory containing all cognitively normal brain images in the train set
```

**requirements.txt** is a text file pinning all required dependencies at specific versions.

**dataset.py** contains the functions for loading the train and test images; the dataset directory variables `train_data_dir` and `test_data_dir` should be modified to the proper paths if the data is stored elsewhere.

**module.py** contains the source code of the vision transformer, including the implementation of the `Patches` and `Patch2Vec` classes and the `multi_layer_preceptron` and `vision_transformer` functions.

- The `Patches` class splits a raw image into patches.
- The `Patch2Vec` class encodes patches into vectors.
- `multi_layer_preceptron` implements the multilayer perceptron blocks placed inside the vision transformer.
- `vision_transformer` builds the vision transformer model itself.

**train.py** contains the source code for model training; to ensure reproducibility, please keep all parameters unchanged. `run_experiment` is the function that defines the optimizer and the model checkpoint, and the training loop is implemented within it.

**predict.py** contains the source code for printing and plotting model performance on the test set.


# Preprocessing
The original dataset has train and test directories. The images in the train directory are used for training and validation, and the images in the test directory for testing. I split the images in the train directory into training and validation data with a 7:3 ratio.
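
In code, the split is done by `keras.utils.image_dataset_from_directory` in dataset.py; a minimal sketch (the batch size of 32 is an assumption):

```python
from tensorflow import keras

# 70/30 split of the train directory; the same seed must be used in both
# calls so the two subsets are disjoint.
common = dict(label_mode='binary', image_size=(256, 256), batch_size=32,
              validation_split=0.3, seed=77)
train_ds = keras.utils.image_dataset_from_directory(
    './ADNI_AD_NC_2D/AD_NC/train', subset='training', **common)
val_ds = keras.utils.image_dataset_from_directory(
    './ADNI_AD_NC_2D/AD_NC/train', subset='validation', **common)
```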

# Changes compared to original ViT model
In module.py, I removed the data augmentation layers, so the original image data (256, 256, 3) is split directly into patches and encoded into vectors. This setting increases the training accuracy and still yields an acceptable test accuracy. In addition, I increased the patch size and thereby reduced the number of patches, since the number of patches equals (image_size // patch_len) ** 2; the patch count drops from 400+ to 20+. Furthermore, I implemented the ReduceLROnPlateau callback so the learning rate is reduced when the loss stops improving. Finally, to keep the model from becoming too complicated, I reduced the neurons in the MLP head from \[2048, 1024\] to \[1024, 512\].
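
To make the patch-count arithmetic concrete, the patch lengths 12 and 48 below are assumptions chosen to reproduce the reported counts (train.py holds the actual values):

```python
def patch_count(image_size: int, patch_len: int) -> int:
    """Number of non-overlapping square patches extracted from a square image."""
    return (image_size // patch_len) ** 2

print(patch_count(256, 12))  # 441 -> the original "400+" patch count (assumed patch_len)
print(patch_count(256, 48))  # 25  -> the reduced "20+" patch count (assumed patch_len)
```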

# Steps to Reproduce the Experiment
- Download dataset from https://cloudstor.aarnet.edu.au/plus/s/L6bbssKhUoUdTSI/download and unzip it into the root directory.
- Install all the dependencies in requirements.txt.
- Make sure you `cd` to the root directory and that the dataset directory structure is ./ADNI_AD_NC_2D/AD_NC/.
- Run train.py to start training; after training, run predict.py to evaluate the model.

# Output
After training and saving the model weights, we can use them for prediction. Loading the file "model.h5" restores the corresponding weights into the model. Evaluated on the test set, the model achieves 83.27% accuracy.

### Train Accuracy
![accuracy.png](https://i.postimg.cc/T3MqvL7w/accuracy.png)

### Train Loss
![loss.png](https://i.postimg.cc/tC5ZJ1YM/loss.png)


# Example of Dataset Image

### Alzheimer’s disease image
![218391-78.jpg](https://i.postimg.cc/2jLcjkKh/218391-78.jpg)

### Normal image
![808819-88.jpg](https://i.postimg.cc/Gh8CCnk5/808819-88.jpg)
39 changes: 39 additions & 0 deletions recognition/Yeheng_Sun_S4548085_COMP3710_ADNIClassifier/dataset.py
@@ -0,0 +1,39 @@
# import packages
from tensorflow import keras
import warnings

# ignore warnings
warnings.filterwarnings('ignore')

train_data_dir = './ADNI_AD_NC_2D/AD_NC/train' # train data directory
test_data_dir = './ADNI_AD_NC_2D/AD_NC/test' # test data directory
class_name_list = ['AD', 'NC'] # list of class names
img_size = 256

# train and validation data loaders: the train directory is split 70/30
def createTrainData(img_size, batch_size):
    train_ds = keras.utils.image_dataset_from_directory(
        directory=train_data_dir,           # target data directory
        labels='inferred',                  # labels inferred from sub-directory names
        label_mode='binary',                # only 2 classes, tagged with value 0 or 1
        batch_size=batch_size,
        image_size=(img_size, img_size),    # the size after resizing
        subset='training',                  # 70% of the train directory for training
        validation_split=0.3,               # hold out 30% of train data for validation
        seed=77                             # same seed keeps the two subsets disjoint
    )
    val_ds = keras.utils.image_dataset_from_directory(
        directory=train_data_dir,
        labels='inferred',
        label_mode='binary',
        batch_size=batch_size,
        image_size=(img_size, img_size),
        subset='validation',                # the held-out 30%
        validation_split=0.3,
        seed=77
    )
    return train_ds, val_ds

# test data loader
def createTestData(img_size, batch_size):
    test_ds = keras.utils.image_dataset_from_directory(
        directory=test_data_dir,
        labels='inferred',
        label_mode='binary',
        batch_size=batch_size,
        image_size=(img_size, img_size),
    )
    return test_ds
81 changes: 81 additions & 0 deletions recognition/Yeheng_Sun_S4548085_COMP3710_ADNIClassifier/module.py
@@ -0,0 +1,81 @@
# import packages
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, Dropout, Layer, Embedding, Input, \
    LayerNormalization, MultiHeadAttention, Add, Flatten

img_size = 256  # must match the image size in train.py and dataset.py



# define an MLP given a list recording the number of nodes in each layer
def multi_layer_preceptron(x, layer_list, drop_out_rate):
    for layer_node in layer_list:
        x = Dense(layer_node, activation=tf.nn.gelu)(x)
        x = Dropout(drop_out_rate)(x)
    return x


# a layer that encodes each patch into a vector with a learned position embedding
class Patch2Vec(Layer):
    def __init__(self, patch_n, proj_vec_n):
        super(Patch2Vec, self).__init__()
        self.patch_n = patch_n
        self.proj_layer = Dense(units=proj_vec_n)  # linear projection of flattened patches
        self.position_embed_layer = Embedding(
            input_dim=patch_n, output_dim=proj_vec_n
        )

    def call(self, patch):
        position = tf.range(start=0, limit=self.patch_n, delta=1)
        encode_vec = self.proj_layer(patch) + self.position_embed_layer(position)
        return encode_vec


# a layer that splits images into patches
class Patches(Layer):
    def __init__(self, patch_len):
        super(Patches, self).__init__()
        self.patch_len = patch_len

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_len, self.patch_len, 1],
            strides=[1, self.patch_len, self.patch_len, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dim = patches.shape[-1]  # flattened size of one patch: patch_len * patch_len * channels
        patches = tf.reshape(patches, [batch_size, -1, patch_dim])
        return patches


def vision_transformer(input_shape, patch_len, patch_n, proj_vec_n, transformer_n, head_n, class_n,
                       transformer_units, mlp_head_units):

    # patch extraction and encoding (data augmentation layers were removed)
    input_img = Input(shape=input_shape)
    patch_img = Patches(patch_len)(input_img)
    patch_vec = Patch2Vec(patch_n, proj_vec_n)(patch_img)

    # transformer modules
    for _ in range(transformer_n):
        x1 = LayerNormalization()(patch_vec)
        attention_output = MultiHeadAttention(
            num_heads=head_n, key_dim=proj_vec_n, dropout=0.1
        )(x1, x1)
        x2 = Add()([attention_output, patch_vec])  # residual connection around attention
        x3 = LayerNormalization()(x2)
        x3 = multi_layer_preceptron(x3, layer_list=transformer_units, drop_out_rate=0.1)
        patch_vec = Add()([x3, x2])                # residual connection around the MLP

    # MLP classifier head
    feature = LayerNormalization()(patch_vec)
    feature = Flatten()(feature)
    feature = Dropout(0.5)(feature)
    feature = multi_layer_preceptron(feature, layer_list=mlp_head_units, drop_out_rate=0.5)
    output = Dense(class_n, activation='sigmoid')(feature)
    model = keras.Model(inputs=input_img, outputs=output)
    return model
16 changes: 16 additions & 0 deletions recognition/Yeheng_Sun_S4548085_COMP3710_ADNIClassifier/predict.py
@@ -0,0 +1,16 @@
from tensorflow import keras
import module
import dataset

# NOTE: these hyperparameters must match the ones used in train.py;
# the values below are assumptions consistent with the README.
img_size = 256
batch_size = 32
patch_len = 48
patch_n = (img_size // patch_len) ** 2
test_ds = dataset.createTestData(img_size, batch_size)
model = module.vision_transformer(
    input_shape=(img_size, img_size, 3),
    patch_len=patch_len,
    patch_n=patch_n,
    proj_vec_n=64,
    transformer_n=8,
    head_n=4,
    class_n=1,
    transformer_units=[128, 64],
    mlp_head_units=[1024, 512],
)
model.compile(optimizer='adam', loss=keras.losses.BinaryCrossentropy(),
              metrics=[keras.metrics.BinaryAccuracy()])

# Load the trained weights
model.load_weights('./utils/model.h5')

# Re-evaluate the model on the test set
_, acc = model.evaluate(test_ds)
print(f"Test acc: {round(acc * 100, 2)}%")
140 changes: 140 additions & 0 deletions recognition/Yeheng_Sun_S4548085_COMP3710_ADNIClassifier/requirements.txt
@@ -0,0 +1,140 @@
absl-py==1.3.0
anyio @ file:///C:/ci/anyio_1644481921011/work/dist
argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work
argon2-cffi-bindings @ file:///C:/ci/argon2-cffi-bindings_1644551690056/work
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
astunparse==1.6.3
attrs @ file:///opt/conda/conda-bld/attrs_1642510447205/work
Babel @ file:///tmp/build/80754af9/babel_1620871417480/work
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
beautifulsoup4 @ file:///C:/ci/beautifulsoup4_1650293025093/work
bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work
brotlipy==0.7.0
cachetools==5.2.0
certifi @ file:///C:/b/abs_ac29jvt43w/croot/certifi_1665076682579/work/certifi
cffi @ file:///C:/Windows/Temp/abs_6808y9x40v/croots/recipe/cffi_1659598653989/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
colorama @ file:///C:/Windows/TEMP/abs_9439aeb1-0254-449a-96f7-33ab5eb17fc8apleb4yn/croots/recipe/colorama_1657009099097/work
contourpy==1.0.5
cryptography @ file:///C:/ci/cryptography_1652083563162/work
cycler==0.11.0
debugpy @ file:///C:/ci/debugpy_1637091961445/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
entrypoints @ file:///C:/ci/entrypoints_1649926621128/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
fastjsonschema @ file:///C:/Users/BUILDE~1/AppData/Local/Temp/abs_ebruxzvd08/croots/recipe/python-fastjsonschema_1661376484940/work
flatbuffers==2.0.7
fonttools==4.38.0
gast==0.4.0
google-auth==2.13.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.50.0
h5py==3.7.0
idna @ file:///C:/b/abs_bdhbebrioa/croot/idna_1666125572046/work
importlib-metadata @ file:///C:/ci/importlib-metadata_1648562621412/work
ipykernel @ file:///C:/b/abs_21ykzkm7y_/croots/recipe/ipykernel_1662361803478/work
ipython @ file:///C:/ci/ipython_1657634415474/work
ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1634143127070/work
jedi @ file:///C:/ci/jedi_1644315428289/work
Jinja2 @ file:///opt/conda/conda-bld/jinja2_1647436528585/work
joblib==1.2.0
json5 @ file:///tmp/build/80754af9/json5_1624432770122/work
jsonschema @ file:///C:/b/abs_59eyhnbyej/croots/recipe/jsonschema_1663375476535/work
jupyter @ file:///C:/Windows/TEMP/abs_56xfdi__li/croots/recipe/jupyter_1659349053177/work
jupyter-console @ file:///opt/conda/conda-bld/jupyter_console_1647002188872/work
jupyter-server @ file:///C:/Windows/TEMP/abs_d3c42c59-765d-4f9b-9fa3-ad5b1369485611i_yual/croots/recipe/jupyter_server_1658754493238/work
jupyter_client @ file:///C:/b/abs_8fbm7986b_/croots/recipe/jupyter_client_1662504374117/work
jupyter_core @ file:///C:/b/abs_a9330r1z_i/croots/recipe/jupyter_core_1664917313457/work
jupyterlab @ file:///C:/ci/jupyterlab_1658891142428/work
jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
jupyterlab_server @ file:///C:/ci/jupyterlab_server_1664893164497/work
keras==2.7.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe @ file:///C:/ci/markupsafe_1654508077284/work
matplotlib==3.6.1
matplotlib-inline @ file:///C:/ci/matplotlib-inline_1661915841596/work
mistune @ file:///C:/ci/mistune_1607359457024/work
nbclassic @ file:///opt/conda/conda-bld/nbclassic_1644943264176/work
nbclient @ file:///C:/ci/nbclient_1650290387259/work
nbconvert @ file:///C:/ci/nbconvert_1649741016669/work
nbformat @ file:///C:/b/abs_1dw90o2uqb/croots/recipe/nbformat_1663744957967/work
nest-asyncio @ file:///C:/ci/nest-asyncio_1649829929390/work
notebook @ file:///C:/Windows/TEMP/abs_79abr1_60s/croots/recipe/notebook_1659083661851/work
numpy==1.23.4
oauthlib==3.2.2
opencv-python==4.6.0.66
opt-einsum==3.3.0
packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
pandas==1.5.1
pandocfilters @ file:///opt/conda/conda-bld/pandocfilters_1643405455980/work
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow==9.2.0
ply==3.11
prometheus-client @ file:///C:/Windows/TEMP/abs_ab9nx8qb08/croots/recipe/prometheus_client_1659455104602/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1633440160888/work
protobuf==3.19.6
psutil @ file:///C:/Windows/Temp/abs_b2c2fd7f-9fd5-4756-95ea-8aed74d0039flsd9qufz/croots/recipe/psutil_1656431277748/work
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work
pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
pyparsing @ file:///C:/Users/BUILDE~1/AppData/Local/Temp/abs_7f_7lba6rl/croots/recipe/pyparsing_1661452540662/work
PyQt5==5.15.7
PyQt5-sip @ file:///C:/Windows/Temp/abs_d7gmd2jg8i/croots/recipe/pyqt-split_1659273064801/work/pyqt_sip
pyrsistent @ file:///C:/ci/pyrsistent_1636093225342/work
PySocks @ file:///C:/ci/pysocks_1605307512533/work
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
pytz @ file:///C:/Windows/TEMP/abs_90eacd4e-8eff-491e-b26e-f707eba2cbe1ujvbhqz1/croots/recipe/pytz_1654762631027/work
pywin32==302
pywinpty @ file:///C:/ci_310/pywinpty_1644230983541/work/target/wheels/pywinpty-2.0.2-cp39-none-win_amd64.whl
pyzmq @ file:///C:/ci/pyzmq_1657615952984/work
qtconsole @ file:///C:/ci/qtconsole_1662019208635/work
QtPy @ file:///C:/ci/qtpy_1662015096047/work
requests @ file:///C:/ci/requests_1657735342357/work
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.1.2
scipy==1.9.3
Send2Trash @ file:///tmp/build/80754af9/send2trash_1632406701022/work
sip @ file:///C:/Windows/Temp/abs_b8fxd17m2u/croots/recipe/sip_1659012372737/work
six @ file:///tmp/build/80754af9/six_1644875935023/work
sklearn==0.0
sniffio @ file:///C:/ci/sniffio_1614030527509/work
soupsieve @ file:///C:/b/abs_fasraqxhlv/croot/soupsieve_1666296394662/work
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.7.0
tensorflow-addons==0.18.0
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.27.0
termcolor==2.0.1
terminado @ file:///C:/ci/terminado_1644322780199/work
testpath @ file:///C:/Windows/TEMP/abs_23c7fa33-cbb9-46dc-b7c5-590c38e2de3d4bmbngal/croots/recipe/testpath_1655908553202/work
threadpoolctl==3.1.0
toml @ file:///tmp/build/80754af9/toml_1616166611790/work
tornado @ file:///C:/ci/tornado_1662458743919/work
traitlets @ file:///tmp/build/80754af9/traitlets_1636710298902/work
typeguard==2.13.3
typing_extensions @ file:///C:/Windows/TEMP/abs_dd2d0moa85/croots/recipe/typing_extensions_1659638831135/work
urllib3 @ file:///C:/b/abs_a8_3vfznn_/croot/urllib3_1666298943664/work
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
webencodings==0.5.1
websocket-client @ file:///C:/ci/websocket-client_1614804375980/work
Werkzeug==2.2.2
widgetsnbextension @ file:///C:/ci/widgetsnbextension_1644991377168/work
win-inet-pton @ file:///C:/ci/win_inet_pton_1605306162074/work
wincertstore==0.2
wrapt==1.14.1
zipp @ file:///C:/ci/zipp_1652273994994/work