An end-to-end implementation of a custom training and validation loop for a CNN on the Fashion MNIST dataset. This project demonstrates low-level model training using tf.GradientTape and tf.keras.metrics, without relying on model.fit().

Custom DNN Training Loop for Fashion MNIST Classification

Problem Statement and Goal of Project

Standard high-level APIs like model.fit() in Keras are powerful but abstract away the underlying training mechanics. The goal of this project is to demonstrate a deeper understanding of the training process by implementing a complete training and validation loop from scratch using TensorFlow's lower-level tools. This approach provides granular control over the model's training, which is essential for advanced architectures, research, and debugging.

Solution Approach

This project builds and trains a Convolutional Neural Network (CNN) on the Fashion MNIST dataset without using the standard model.fit() method. The entire process is manually defined to showcase proficiency with core TensorFlow components:

  1. Data Pipeline: The Fashion MNIST dataset was loaded using tensorflow_datasets, ensuring an efficient and scalable input pipeline. The data was batched, shuffled, and preprocessed by normalizing pixel values.
  2. Model Definition: A CNN was constructed using the Keras Functional API. The architecture includes Conv2D, MaxPooling2D, Flatten, and Dense layers, with Dropout for regularization.
  3. Custom Training Loop: The core of the project is a custom training loop that iterates through epochs and batches. For each batch, the following steps are explicitly performed:
    • Forward Pass: Predictions (logits) are generated by passing the input batch through the model.
    • Loss Calculation: The SparseCategoricalCrossentropy loss between the true labels and the predictions is computed.
    • Gradient Computation: tf.GradientTape is used to automatically calculate the gradients of the loss with respect to the model's trainable weights.
    • Weight Update: The Adam optimizer applies these gradients to update the model's weights, minimizing the loss.
  4. Metrics and Validation:
    • tf.keras.metrics (specifically SparseCategoricalAccuracy) were used to track training and validation accuracy.
    • A separate validation function was created to evaluate the model's performance on the test set at the end of each epoch.
    • The metrics are manually updated after each batch and reset at the beginning of each epoch.
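
The input pipeline in step 1 can be sketched roughly as follows; the batch size and shuffle-buffer size are illustrative assumptions, not values taken from the notebook:

```python
import tensorflow as tf

def preprocess(image, label):
    # Normalize pixel values from [0, 255] to [0.0, 1.0].
    return tf.cast(image, tf.float32) / 255.0, label

def build_pipeline(ds, batch_size=32, shuffle_buffer=1024):
    # Map -> shuffle -> batch -> prefetch: the usual tf.data ordering.
    return (ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(shuffle_buffer)
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE))

# With TFDS, the raw dataset would come from something like:
# import tensorflow_datasets as tfds
# train_ds = tfds.load('fashion_mnist', split='train', as_supervised=True)
```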
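Step 2's model might look like this with the Functional API (the filter counts and dropout rate are assumptions for illustration; the notebook's exact values may differ):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inputs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Conv2D(64, 3, activation='relu')(x)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(10)(x)  # raw logits for the 10 classes
model = tf.keras.Model(inputs, outputs)
```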
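The per-batch steps in 3 and the metric bookkeeping in 4 can be sketched as a single train_step. The tiny stand-in model below is only there to keep the snippet self-contained; it is not the CNN described above:

```python
import tensorflow as tf

# Minimal stand-in model so the snippet runs on its own.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),  # logits
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = model(images, training=True)  # forward pass
        loss = loss_fn(labels, logits)         # loss calculation
    # Gradients of the loss w.r.t. every trainable weight.
    grads = tape.gradient(loss, model.trainable_weights)
    # Weight update via Adam.
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc.update_state(labels, logits)
    return loss
```

A validation step is the same forward pass and metric update, but with `training=False` and no tape.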

Technologies & Libraries

  • TensorFlow: For building the model, creating the custom training loop, and using tf.GradientTape.
  • TensorFlow Datasets (TFDS): For efficiently loading and preprocessing the Fashion MNIST dataset.
  • NumPy: For data manipulation.
  • Matplotlib: For visualizing the dataset and plotting performance metrics.
  • Tqdm: To provide a progress bar for monitoring training steps.

Dataset Description

The Fashion MNIST dataset is used for this project. It is a collection of 70,000 grayscale images (60,000 for training and 10,000 for testing) of 10 different types of clothing items (e.g., T-shirt, trouser, coat). Each image is 28x28 pixels. This dataset serves as a more challenging drop-in replacement for the original MNIST digit dataset.

Installation & Execution Guide

  1. Clone the repository:
    git clone <repository-url>
    cd <repository-directory>
  2. Install the required libraries:
    pip install tensorflow tensorflow-datasets numpy tqdm matplotlib
  3. Run the Jupyter Notebook:
    jupyter notebook dnn_custom_training.ipynb

Key Results / Performance

The model was trained for 10 epochs using the custom loop. The performance metrics below demonstrate that the manually implemented training process was successful in teaching the model to classify the clothing items effectively.

  • Final Validation Accuracy: ~88.45%
  • Final Training Accuracy: ~89.10%

The loss curves show a stable decrease for both training and validation sets, indicating good model convergence.

Screenshots / Sample Output

Sample Images from the Fashion MNIST Dataset

Training Progress

The output below shows the loss and accuracy metrics at the end of each epoch during the custom training loop.

Epoch 9: Train loss: 0.3117  Validation Loss: 0.3341, Train Accuracy: 0.8910, Validation Accuracy 0.8845

Loss vs. Epochs Plot

Sample Predictions on Test Data

Additional Learnings / Reflections

This project was an excellent opportunity to look under the hood of TensorFlow's training process and solidify my understanding of the mechanics of deep learning.

  • Demystifying model.fit(): Implementing the training loop manually made the steps involved in model.fit()—such as the forward pass, gradient calculation, and weight updates—very clear and intuitive.
  • Mastering tf.GradientTape: I gained significant hands-on experience with tf.GradientTape, which is a critical tool for any non-standard model or research application in TensorFlow.
  • Stateful Metrics: I learned how to properly manage stateful metrics like accuracy, including updating their state with each batch and resetting them between epochs. This is a subtle but crucial aspect of custom training.
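
That batch-update / epoch-reset pattern can be shown in isolation (the labels and predictions here are made up for the example):

```python
import tensorflow as tf

# A stateful metric accumulates counts across update_state() calls.
acc = tf.keras.metrics.SparseCategoricalAccuracy()

# Batch 1: both predictions match their labels.
acc.update_state([1, 2], [[0.1, 0.9, 0.0], [0.0, 0.1, 0.9]])
# Batch 2: one of the two predictions is correct.
acc.update_state([0, 0], [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1]])

epoch_accuracy = float(acc.result())  # 3 correct out of 4 -> 0.75

# Reset at the start of the next epoch so state does not leak across epochs.
acc.reset_state()
```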

This foundational knowledge is invaluable for debugging, implementing custom architectures (like GANs), and having the flexibility to go beyond standard training procedures when a project requires it.


💡 Some interactive outputs (e.g., plots, widgets) may not display correctly on GitHub. If so, please view this notebook via nbviewer.org for full rendering.

👤 Author

Mehran Asgari


📄 License

This project is licensed under the Apache 2.0 License – see the LICENSE file for details.
