
Commit a4c9850

cleanup distillation loss names (#21766)
* cleanup distillation api names
* Update keras/src/distillation/distillation_loss.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* code reformat
* update docstring
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent fb2f025 commit a4c9850

File tree: 4 files changed (+162 / -123 lines changed)


keras/src/distillation/distillation_loss.py

Lines changed: 9 additions & 6 deletions
@@ -77,15 +77,15 @@ def validate_model_compatibility(self, teacher, student):
             teacher: The teacher model.
             student: The student model.
         Raises:
-            ValueError: If models are not compatible with this
-                distillation_loss.
+            ValueError: If models are not compatible with this distillation
+                loss.
         """
         pass
 
 
 @keras_export("keras.distillation.FeatureDistillation")
 class FeatureDistillation(DistillationLoss):
-    """Feature distillation distillation_loss.
+    """Feature distillation loss.
 
     Feature distillation transfers knowledge from intermediate layers of the
     teacher model to corresponding layers of the student model. This approach
@@ -99,7 +99,7 @@ class FeatureDistillation(DistillationLoss):
             - Nested structure of losses matching the layer output structure
             - `None` to skip distillation for that output (useful for
               multi-output models where you only want to distill some outputs)
-            At least one loss must be non-None. Defaults to 'mse'.
+            At least one loss must be non-`None`. Defaults to 'mse'.
         teacher_layer_name: Name of the teacher layer to extract features from.
             If `None`, uses the final output. Defaults to `None`.
         student_layer_name: Name of the student layer to extract features from.
@@ -152,7 +152,10 @@ def __init__(
 
         flat_losses = tree.flatten(self.loss)
         if all(l is None for l in flat_losses):
-            raise ValueError("At least one loss must be non-None.")
+            raise ValueError(
+                "The `loss` argument in `FeatureDistillation` must "
+                "contain at least one non-`None` value."
+            )
 
     def validate_model_compatibility(self, teacher, student):
         """Validate that teacher and student models are compatible for feature
@@ -258,7 +261,7 @@ def from_config(cls, config):
 class LogitsDistillation(DistillationLoss):
     """Distillation loss that transfers knowledge from final model outputs.
 
-    This distillation_loss applies temperature scaling to the teacher's logits
+    This distillation loss applies temperature scaling to the teacher's logits
     before computing the loss between teacher and student predictions. It's the
     most common approach for knowledge distillation.
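For illustration only (not part of this commit): a minimal sketch of the FeatureDistillation behavior touched by these hunks. The import path follows the keras_export decorator above, and the constructor arguments mirror the test file below; the all-None `loss` value is an assumption used to trigger the reworded error.

# Illustrative sketch -- not part of the commit. The import path follows
# @keras_export("keras.distillation.FeatureDistillation"); argument values
# are assumptions for demonstration.
from keras.distillation import FeatureDistillation

# Succeeds: at least one entry of `loss` is non-None.
feature_loss = FeatureDistillation(
    loss="mse",
    teacher_layer_name="teacher_dense_1",
    student_layer_name="student_dense_1",
)

# Fails with the reworded message introduced in this commit:
# "The `loss` argument in `FeatureDistillation` must contain at least
# one non-`None` value."
try:
    FeatureDistillation(loss=[None, None])
except ValueError as err:
    print(err)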

keras/src/distillation/distillation_loss_test.py

Lines changed: 3 additions & 3 deletions
@@ -100,7 +100,7 @@ def test_logits_distillation_end_to_end(self):
         distiller = Distiller(
             teacher=self.teacher,
             student=self.student,
-            distillation_loss=LogitsDistillation(temperature=3.0),
+            distillation_losses=LogitsDistillation(temperature=3.0),
             student_loss_weight=0.5,
         )
 
@@ -138,7 +138,7 @@ def test_feature_distillation_end_to_end(self):
         distiller = Distiller(
             teacher=self.teacher,
             student=self.student,
-            distillation_loss=FeatureDistillation(
+            distillation_losses=FeatureDistillation(
                 loss="mse",
                 teacher_layer_name="teacher_dense_1",
                 student_layer_name="student_dense_1",
@@ -194,7 +194,7 @@ def test_multi_distillation_loss_distillation_end_to_end(self):
         distiller = Distiller(
             teacher=self.teacher,
             student=self.student,
-            distillation_loss=distillation_loss,
+            distillation_losses=distillation_loss,
             distillation_loss_weights=[1.0, 0.5, 0.3],
             student_loss_weight=0.5,
         )
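For context, a hedged usage sketch of the renamed keyword. Only `distillation_losses=`, `LogitsDistillation(temperature=...)`, and `student_loss_weight=` come from the diff above; the `keras.distillation` import path for Distiller and the toy teacher/student models are assumptions.

# Illustrative sketch only. The Distiller import path and the toy models
# are assumptions; the keyword rename is what this commit's tests exercise.
import keras
from keras.distillation import Distiller, LogitsDistillation

teacher = keras.Sequential(
    [keras.layers.Dense(16, activation="relu"), keras.layers.Dense(10)]
)
student = keras.Sequential(
    [keras.layers.Dense(8, activation="relu"), keras.layers.Dense(10)]
)

# After this commit the Distiller keyword is `distillation_losses`
# (previously `distillation_loss`).
distiller = Distiller(
    teacher=teacher,
    student=student,
    distillation_losses=LogitsDistillation(temperature=3.0),
    student_loss_weight=0.5,
)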
