Standard high-level APIs like `model.fit()` in Keras are powerful but abstract away the underlying training mechanics. The goal of this project is to demonstrate a deeper understanding of the training process by implementing a complete training and validation loop from scratch using TensorFlow's lower-level tools. This approach provides granular control over training, which is essential for advanced architectures, research, and debugging.
This project builds and trains a Convolutional Neural Network (CNN) on the Fashion MNIST dataset without using the standard `model.fit()` method. The entire process is manually defined to showcase proficiency with core TensorFlow components:
- Data Pipeline: The Fashion MNIST dataset was loaded using `tensorflow_datasets`, ensuring an efficient and scalable input pipeline. The data was batched, shuffled, and preprocessed by normalizing pixel values.
- Model Definition: A CNN was constructed using the Keras Functional API. The architecture includes `Conv2D`, `MaxPooling2D`, `Flatten`, and `Dense` layers, with `Dropout` for regularization.
- Custom Training Loop: The core of the project is a custom training loop that iterates through epochs and batches. For each batch, the following steps are explicitly performed:
- Forward Pass: Predictions (logits) are generated by passing the input batch through the model.
- Loss Calculation: The `SparseCategoricalCrossentropy` loss between the true labels and the predictions is computed.
- Gradient Computation: `tf.GradientTape` is used to automatically calculate the gradients of the loss with respect to the model's trainable weights.
- Weight Update: The `Adam` optimizer applies these gradients to update the model's weights, minimizing the loss.
- Metrics and Validation: `tf.keras.metrics` (specifically `SparseCategoricalAccuracy`) were used to track training and validation accuracy.
- A separate validation function was created to evaluate the model's performance on the test set at the end of each epoch.
- The metrics are manually updated after each batch and reset at the beginning of each epoch.
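The per-batch steps above can be sketched as follows. This is a minimal illustrative version, not the notebook's exact code: a tiny `Dense` stand-in model and random tensors replace the CNN and the Fashion MNIST batches so the snippet runs on its own.

```python
import tensorflow as tf

# Toy stand-in model (the notebook uses a CNN on 28x28x1 images).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(3),  # raw logits for 3 classes
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)   # forward pass
        loss = loss_fn(y, logits)          # loss calculation
    # gradient computation w.r.t. trainable weights
    grads = tape.gradient(loss, model.trainable_variables)
    # weight update via the optimizer
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_acc.update_state(y, logits)      # stateful metric update
    return loss

# One synthetic "batch" to exercise the step.
x = tf.random.normal((8, 4))
y = tf.random.uniform((8,), maxval=3, dtype=tf.int32)
loss = train_step(x, y)
```

The full loop simply calls `train_step` for every batch in every epoch, logging `loss` and `train_acc.result()` along the way.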
- TensorFlow: For building the model, creating the custom training loop, and using `tf.GradientTape`.
- TensorFlow Datasets (TFDS): For efficiently loading and preprocessing the Fashion MNIST dataset.
- NumPy: For data manipulation.
- Matplotlib: For visualizing the dataset and plotting performance metrics.
- Tqdm: To provide a progress bar for monitoring training steps.
The Fashion MNIST dataset is used for this project. It is a collection of 70,000 grayscale images (60,000 for training and 10,000 for testing) of 10 different types of clothing items (e.g., T-shirt, trouser, coat). Each image is 28x28 pixels. This dataset serves as a more challenging drop-in replacement for the original MNIST digit dataset.
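The batching, shuffling, and normalization described above can be sketched with `tf.data`. Since the TFDS download is not reproduced here, random `uint8` tensors stand in for the real images (the actual `tfds.load` call is shown in a comment); shapes and dtypes mirror Fashion MNIST.

```python
import tensorflow as tf

def normalize(image, label):
    # Cast uint8 pixels to float32 in [0, 1].
    return tf.cast(image, tf.float32) / 255.0, label

# In the notebook the data comes from TFDS, roughly:
#   import tensorflow_datasets as tfds
#   train_ds = tfds.load("fashion_mnist", split="train", as_supervised=True)
# Here, synthetic tensors stand in so the pipeline runs offline.
images = tf.cast(
    tf.random.uniform((100, 28, 28, 1), maxval=256, dtype=tf.int32), tf.uint8)
labels = tf.random.uniform((100,), maxval=10, dtype=tf.int32)
train_ds = tf.data.Dataset.from_tensor_slices((images, labels))

train_ds = (train_ds
            .shuffle(buffer_size=100)        # shuffle examples
            .map(normalize)                  # normalize pixel values
            .batch(32)                       # batch for training
            .prefetch(tf.data.AUTOTUNE))     # overlap prep and training

batch_images, batch_labels = next(iter(train_ds))
```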
- Clone the repository:

  ```bash
  git clone <repository-url>
  cd <repository-directory>
  ```

- Install the required libraries:

  ```bash
  pip install tensorflow tensorflow-datasets numpy tqdm matplotlib
  ```

- Run the Jupyter Notebook:

  ```bash
  jupyter notebook dnn_custom_training.ipynb
  ```
The model was trained for 10 epochs using the custom loop. The performance metrics below demonstrate that the manually implemented training process was successful in teaching the model to classify the clothing items effectively.
- Final Validation Accuracy: ~88.45%
- Final Training Accuracy: ~89.10%
The loss curves show a stable decrease for both training and validation sets, indicating good model convergence.
Sample Images from the Fashion MNIST Dataset
Training Progress: The output below shows the loss and accuracy metrics at the end of each epoch during the custom training loop.

```
Epoch 9: Train Loss: 0.3117, Validation Loss: 0.3341, Train Accuracy: 0.8910, Validation Accuracy: 0.8845
```
Loss vs. Epochs Plot
Sample Predictions on Test Data
This project was an excellent opportunity to look under the hood of TensorFlow's training process and solidify my understanding of the mechanics of deep learning.
- Demystifying `model.fit()`: Implementing the training loop manually made the steps involved in `model.fit()` (the forward pass, gradient calculation, and weight updates) very clear and intuitive.
- Mastering `tf.GradientTape`: I gained significant hands-on experience with `tf.GradientTape`, which is a critical tool for any non-standard model or research application in TensorFlow.
- Stateful Metrics: I learned how to properly manage stateful metrics like accuracy, including updating their state with each batch and resetting them between epochs. This is a subtle but crucial aspect of custom training.
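The stateful-metric behavior is easy to see in isolation. The sketch below (illustrative, not from the notebook) shows accuracy accumulating across `update_state` calls until `reset_state` clears it, which is why a reset is needed at each epoch boundary.

```python
import tensorflow as tf

acc = tf.keras.metrics.SparseCategoricalAccuracy()

# Batch 1: both predictions correct (argmax matches the label).
acc.update_state([0, 1], [[0.9, 0.1], [0.2, 0.8]])
first = float(acc.result())   # 2/2 correct -> 1.0

# Batch 2: both wrong; the metric accumulates across calls.
acc.update_state([1, 1], [[0.9, 0.1], [0.7, 0.3]])
second = float(acc.result())  # 2/4 correct -> 0.5

# Between epochs the accumulated state must be cleared explicitly.
acc.reset_state()
```

Forgetting the reset makes each epoch's reported accuracy a running average over all previous epochs, which silently distorts the training curves.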
This foundational knowledge is invaluable for debugging, implementing custom architectures (like GANs), and having the flexibility to go beyond standard training procedures when a project requires it.
💡 Some interactive outputs (e.g., plots, widgets) may not display correctly on GitHub. If so, please view this notebook via nbviewer.org for full rendering.
Email: imehranasgari@gmail.com
GitHub: https://github.com/imehranasgari
This project is licensed under the Apache 2.0 License – see the `LICENSE` file for details.