|
18 | 18 | from tensorflow_addons.optimizers import AveragedOptimizerWrapper |
19 | 19 | from tensorflow_addons.utils import types |
20 | 20 |
|
21 | | -from typing import Optional |
| 21 | +from typing import Union |
22 | 22 | from typeguard import typechecked |
23 | 23 |
|
24 | 24 |
|
@@ -47,7 +47,7 @@ def __init__( |
47 | 47 | optimizer: types.Optimizer, |
48 | 48 | sequential_update: bool = True, |
49 | 49 | average_decay: types.FloatTensorLike = 0.99, |
50 | | - num_updates: Optional[str] = None, |
| 50 | + num_updates: Union[None, int, tf.Variable] = None, |
51 | 51 | start_step: int = 0, |
52 | 52 | dynamic_decay: bool = False, |
53 | 53 | name: str = "MovingAverage", |
@@ -82,6 +82,14 @@ def __init__( |
82 | 82 | super().__init__(optimizer, sequential_update, name, **kwargs) |
83 | 83 | self._num_updates = num_updates |
84 | 84 | if self._num_updates is not None: |
| 85 | + if isinstance(self._num_updates, tf.Variable): |
| 86 | + tf.debugging.assert_integer( |
| 87 | + self._num_updates, |
| 88 | + ( |
| 89 | + 'type of argument "num_updates" must be ' |
| 90 | + "int; got {} instead".format(self._num_updates.dtype) |
| 91 | + ), |
| 92 | + ) |
85 | 93 | num_updates = tf.cast(self._num_updates, tf.float32, name="num_updates") |
86 | 94 | average_decay = tf.minimum( |
87 | 95 | average_decay, (1.0 + num_updates) / (10.0 + num_updates) |
|
0 commit comments