This project demonstrates a practical, in-depth understanding of low-level TensorFlow and Keras APIs by implementing custom training loops for two distinct image classification tasks. It begins with a foundational CNN training pipeline on the EuroSAT dataset using `tf.GradientTape` and progresses to an optimized MLP for MNIST classification with advanced performance-enhancing techniques.
High-level APIs such as `model.fit()` abstract much of the training process, which is convenient but limits flexibility when implementing custom logic, advanced optimizers, or non-standard evaluation metrics.
This project addresses that limitation by manually managing the training process, offering complete control over each step.
Key objectives:
- Foundational Implementation: Train a CNN from scratch on EuroSAT satellite images with a custom `tf.GradientTape` loop.
- Performance Optimization: Apply architectural and training improvements to an MNIST classifier to enhance accuracy and robustness.
Two Jupyter Notebooks are provided to illustrate the transition from basic custom training to optimized, production-ready training.
File: `simple_just_for_learning_not_metric.ipynb`

- Dataset Handling: Loaded EuroSAT from `tensorflow_datasets`; split into 70% train, 15% validation, 15% test.
- Data Pipeline: Built with `tf.data` for efficient batching, shuffling, and prefetching.
- Preprocessing:
  - Resize to 64×64
  - Normalize to [0, 1]
  - One-hot encode labels
- Augmentation: Random flips, rotations, zooms, and contrast adjustments (applied on the fly).
- Model: CNN architecture implemented manually.
- Training Loop: Implemented with `tf.GradientTape`, including:
  - Gradient computation and manual weight updates
  - Validation monitoring
  - Model checkpoint saving
  - Early stopping
- Monitoring: Integrated TensorBoard for loss/accuracy visualization.
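As a minimal sketch of the kind of loop the notebook implements, the snippet below runs `tf.GradientTape` training steps over a `tf.data` pipeline with the same preprocessing (resize, normalize, one-hot). The tiny CNN and the random tensors standing in for EuroSAT images are illustrative placeholders, not the notebook's actual architecture or data:

```python
import tensorflow as tf

NUM_CLASSES = 10  # EuroSAT has 10 land-use classes

# Stand-in CNN; the notebook's real architecture is larger.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 64, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES),
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

def preprocess(image, label):
    image = tf.image.resize(image, (64, 64))      # resize to 64x64
    image = tf.cast(image, tf.float32) / 255.0    # normalize to [0, 1]
    return image, tf.one_hot(label, NUM_CLASSES)  # one-hot encode labels

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss = loss_fn(labels, logits)
    # Gradient computation and manual weight update.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Synthetic batch standing in for EuroSAT data.
images = tf.random.uniform((8, 72, 72, 3), maxval=255.0)
labels = tf.random.uniform((8,), maxval=NUM_CLASSES, dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(preprocess).batch(8)

for x, y in ds:
    loss = train_step(x, y)
```

The validation monitoring, checkpointing, and early stopping in the notebook wrap around this inner step.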
File: `low_level_api_better_acc.ipynb`

- Dataset Handling: Loaded MNIST via `keras.datasets`.
- Loss Function: `SparseCategoricalCrossentropy` (works directly with integer labels, memory-efficient).
- Architecture: Multi-Layer Perceptron with:
  - Batch Normalization (stabilizes and accelerates training)
  - Dropout (reduces overfitting)
- Training Loop Enhancements:
  - Early Stopping (manual implementation)
  - ReduceLROnPlateau (learning-rate reduction when validation loss plateaus)
- Evaluation:
  - Test set accuracy
  - Confusion matrix for per-class performance
- Frameworks: TensorFlow, Keras
- Libraries: TensorFlow Datasets, NumPy, Matplotlib, scikit-learn, tqdm
- Tools: Jupyter Notebook, TensorBoard
- EuroSAT: 27,000 labeled satellite images (64×64 px, RGB) in 10 land-use classes (e.g., Forest, River, Industrial).
- MNIST: 70,000 grayscale handwritten digits (28×28 px, classes 0–9).
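For reference, the 70/15/15 EuroSAT split can be expressed with TFDS split slicing. `eurosat/rgb` is the TFDS name for the RGB variant; the load call is left commented out because it triggers a download:

```python
# TFDS split slicing expresses the 70% / 15% / 15% partition directly.
train_split = "train[:70%]"
val_split = "train[70%:85%]"
test_split = "train[85%:]"

# import tensorflow_datasets as tfds
# train_ds, val_ds, test_ds = tfds.load(
#     "eurosat/rgb",
#     split=[train_split, val_split, test_split],
#     as_supervised=True,
# )
```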
```bash
# 1. Clone the repository
git clone https://github.com/imehranasgari/your-repo-name.git
cd your-repo-name

# 2. Install dependencies
pip install -r requirements.txt

# 3. Launch Jupyter Notebook
jupyter notebook
```
Open either:
- `simple_just_for_learning_not_metric.ipynb`
- `low_level_api_better_acc.ipynb`
- EuroSAT Notebook:
  - Robust, reusable training pipeline with augmentation and monitoring.
  - Demonstrates a complete manual training loop.
- MNIST Notebook:
  - Significant accuracy improvement with Batch Normalization, Dropout, and learning-rate scheduling.
  - Clear per-class breakdown via the confusion matrix.
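A per-class breakdown of this kind can be read off the confusion matrix's diagonal; the toy 3-class labels below are purely illustrative (the notebook uses MNIST's 10 digit classes):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy ground-truth labels and model predictions for a 3-class problem.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(y_true, y_pred)           # rows = true, cols = predicted
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # per-class recall
```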
(Use your own prepared screenshots for clarity — examples include:)
- Eurosat dataset sample images
- MNIST training/validation curves
- MNIST confusion matrix
- Moving beyond `model.fit()` provided deeper insight into:
  - Gradient descent and backpropagation
  - Metric calculation
  - Manual control over the optimization flow
- The first notebook emphasized pipeline building; the second showcased model optimization for higher accuracy.
- Some notebooks intentionally use simpler models or achieve lower metrics; these are learning exercises, not production constraints.
Mehran Asgari
Email: imehranasgari@gmail.com
GitHub: https://github.com/imehranasgari
This project is licensed under the Apache 2.0 License. See the `LICENSE` file for details.
💡 Some interactive outputs (e.g., plots, widgets) may not display correctly on GitHub. If so, please view this notebook via nbviewer.org for full rendering.