Skip to content

Commit dc3aa06

Browse files
authored
fix typing of num_updates in moving_average (#2136)
* fix typing of num_updates in moving_average * use .format instead of fstring
1 parent 9acab6a commit dc3aa06

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

tensorflow_addons/optimizers/moving_average.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tensorflow_addons.optimizers import AveragedOptimizerWrapper
1919
from tensorflow_addons.utils import types
2020

21-
from typing import Optional
21+
from typing import Union
2222
from typeguard import typechecked
2323

2424

@@ -47,7 +47,7 @@ def __init__(
4747
optimizer: types.Optimizer,
4848
sequential_update: bool = True,
4949
average_decay: types.FloatTensorLike = 0.99,
50-
num_updates: Optional[str] = None,
50+
num_updates: Union[None, int, tf.Variable] = None,
5151
start_step: int = 0,
5252
dynamic_decay: bool = False,
5353
name: str = "MovingAverage",
@@ -82,6 +82,14 @@ def __init__(
8282
super().__init__(optimizer, sequential_update, name, **kwargs)
8383
self._num_updates = num_updates
8484
if self._num_updates is not None:
85+
if isinstance(self._num_updates, tf.Variable):
86+
tf.debugging.assert_integer(
87+
self._num_updates,
88+
(
89+
'type of argument "num_updates" must be '
90+
"int; got {} instead".format(self._num_updates.dtype)
91+
),
92+
)
8593
num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates")
8694
average_decay = tf.minimum(
8795
average_decay, (1.0 + num_updates) / (10.0 + num_updates)

tensorflow_addons/optimizers/tests/moving_average_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,19 @@ def test_opt_failure():
6868
MovingAverage(base_opt, 0.5)
6969

7070

71+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
72+
def test_num_updates_valid():
73+
for num_updates in [1, tf.Variable(1)]:
74+
MovingAverage("sgd", num_updates=num_updates)
75+
76+
77+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
78+
def test_num_updates_invalid():
79+
for num_updates in [1.0, tf.Variable(1.0), "a"]:
80+
with pytest.raises(TypeError):
81+
MovingAverage("sgd", num_updates=num_updates)
82+
83+
7184
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
7285
def test_model_weights_update():
7386
grad = tf.Variable([[0.1]])

0 commit comments

Comments
 (0)