Keras Functional API: CNNs, Multi-Input/Output Models, and a Toy ResNet (TFDS: MNIST & CIFAR-10)
Demonstrate practical use of TensorFlow/Keras Functional API for:
- A baseline CNN classifier on MNIST,
- A multi-input, multi-output model (ticket triage example with synthetic data),
- A small ResNet-style (“skip connection”) model trained briefly on CIFAR-10,
with tf.data pipelines via TensorFlow Datasets (TFDS). The notebook focuses on correct API usage, modeling patterns, and explainability over squeezing out top metrics.
The notebook implements three modeling patterns:
-
Baseline CNN (MNIST)
- Input shape: (28, 28, 1)
- Architecture (Functional API):
Conv2D(32, 3x3, ReLU) → MaxPool → Conv2D(64, 3x3, ReLU) → MaxPool → Flatten → Dropout(0.5) → Dense(num_classes, softmax) - Loss set via
CategoricalCrossentropy(from_logits=True)withmetrics=["accuracy"/"acc"]in different cells. - Data via TFDS, normalized and one-hot encoded.
-
Multi-Input, Multi-Output model (synthetic)
- Perso-language markdown describes a ticket-ranking use-case (title/body/tags → priority/department).
- Code builds a Keras model with multiple inputs and two outputs (e.g.,
"priority","department"). - Trains on synthetic NumPy arrays (dummy data).
- Uses dictionary losses and
loss_weightsto balance heads. - A diagram is generated with
keras.utils.plot_model("multi_input_and_output_model.png").
-
Toy ResNet-style model (CIFAR-10)
- CIFAR-10 loaded from TFDS with splits:
train[:80%](train),train[80%:](val), andtest. - A small skip-connection (“Rednet/ResNet”) topology is defined using the Functional API (residual pattern).
- Training is run for a very small number of epochs (e.g.,
epochs=1) to keep runtime short. - A diagram is generated with
keras.utils.plot_model("resnet_me.png", show_shapes=True).
- CIFAR-10 loaded from TFDS with splits:
Across all parts:
- TFDS utilities (e.g.,
tfds.visualization.show_examples) are used to inspect samples. - Preprocessing uses
tf.dividefor normalization andtf.one_hot(..., depth=10)for labels. - Batching and shuffling are handled via
tf.data(e.g.,.map(...).shuffle(1000).batch(64)). - Optimizer examples include
keras.optimizers.RMSprop(1e-3).
Note: Some code paths intentionally restrict sample counts or epochs (e.g., 1–2 epochs) to demonstrate API usage and keep execution time low.
- Python (conda kernel name indicates
conda-env-tf-py) - TensorFlow / Keras
- TensorFlow Datasets (tfds)
- NumPy
-
MNIST (from TFDS)
- Split in code:
train,test - Input shape used: (28, 28, 1)
- Labels one-hot encoded with depth 10 (matches MNIST classes).
- Split in code:
-
CIFAR-10 (from TFDS)
- Splits in code:
train[:80%](train),train[80%:](validation),test - Model expects input shape (32, 32, 3).
- Splits in code:
-
Sample visualization via
tfds.visualization.show_examples(...).
If additional dataset specifics (e.g., class names) are needed, they’re standard to MNIST/CIFAR-10 but not explicitly enumerated in the notebook.
Prerequisites
- Python 3.x
- A working TensorFlow environment (GPU optional)
Install (pip)
pip install tensorflow tensorflow-datasets numpy pydot graphviz
graphvizandpydotare required to save the model diagrams viakeras.utils.plot_model. On some systems you must also install Graphviz binaries (e.g.,sudo apt-get install graphviz).
Run
-
Open
functional_API_me.ipynbin Jupyter/VS Code. -
Run cells sequentially.
-
The notebook will:
- Download MNIST and CIFAR-10 from TFDS on first run.
- Train each demo briefly.
- Save model diagrams (e.g.,
my_first_model_with_shape_info.png,multi_input_and_output_model.png,resnet_me.png) when those cells are executed.
- Not provided.
The notebook compiles and fits models (often for 1–2 epochs). It prints evaluation for the MNIST CNN (
model.evaluate(...)) and runs a brief fit on CIFAR-10, but no final metrics are stored in the repository.
- Not provided in the repository.
The code includes calls to
keras.utils.plot_model(...)which will create diagrams locally if executed.
- Why Functional API? It enables non-linear graphs, multiple inputs/outputs, and skip connections—patterns that aren’t possible (or are awkward) in
Sequential. - Loss dictionaries & loss weights: The multi-head example shows how to balance objectives with
{"priority": ..., "department": ...}andloss_weights. - Data pipelines with TFDS +
tf.data: The project demonstrates clean mapping, normalization, one-hot encoding, shuffling, and batching directly from TFDS. - Exploration over SOTA: Several cells intentionally run short trainings or use synthetic data to teach modeling mechanics (e.g., model wiring, compilation options, plotting architectures) rather than to maximize accuracy.
Mehran Asgari Email: imehranasgari@gmail.com GitHub: https://github.com/imehranasgari
This project is licensed under the Apache 2.0 License – see the LICENSE file for details.
💡 Some interactive outputs (e.g., plots, widgets) may not display correctly on GitHub. If so, please view this notebook via nbviewer.org for full rendering.