Skip to content

Commit 39ac369

Browse files
Support saving and loading sharded optimizer variables into checkpoints across arbitrary sharding configurations.
Because optimizer variables are initialized from the model's individual shard variables using `colocate_with`, they are not aggregated into a `ShardedVariable` upon creation. This change accumulates and creates `ShardedVariable` objects for optimizer variables automatically as they are created from corresponding model variables, when those model variables are shards. Optimizer attribute variables (e.g. Adam's `self._momentums`) are kept as regular Variables to maintain existing functionality. In order to save and restore optimizer attribute variables correctly across varying shard configurations, they are replaced with their `ShardedVariable` containers in the checkpointing object graph. PiperOrigin-RevId: 590322532
1 parent f0fb8b4 commit 39ac369

File tree

3 files changed

+210
-3
lines changed

3 files changed

+210
-3
lines changed

tf_keras/optimizers/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ distribute_py_test(
104104
name = "optimizer_pss_test",
105105
size = "medium",
106106
srcs = ["optimizer_pss_test.py"],
107-
shard_count = 32,
107+
shard_count = 50,
108108
tags = [
109109
"multi_gpu",
110110
"no_oss",

tf_keras/optimizers/optimizer.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def __init__(
102102
)
103103

104104
self._variables = []
105+
# A dict mapping a model ShardedVariable id to an object that builds a
106+
# ShardedVariable from the corresponding optimizer variables. See
107+
# `add_variable_from_reference`.
108+
self._sharded_variable_builders = self._no_dependency({})
105109
self._create_iteration_variable()
106110
self._process_kwargs(kwargs)
107111

@@ -516,9 +520,76 @@ def add_variable_from_reference(
516520
dtype=model_variable.dtype,
517521
trainable=False,
518522
)
519-
self._variables.append(variable)
523+
# If model_variable is a shard of a ShardedVariable, we should add a
524+
# ShardedVariable for all related optimizer variables so that
525+
# checkpointing is robust to different partitionings. Use unique_id to
526+
# dedup ShardedVariables.
527+
if hasattr(model_variable, "_sharded_container"):
528+
sharded_variable = model_variable._sharded_container()
529+
# Get or create builder object
530+
sv_builder = self._sharded_variable_builders.setdefault(
531+
(sharded_variable._unique_id, variable_name),
532+
_ShardedVariableBuilder(len(sharded_variable.variables)),
533+
)
534+
sv_builder.add_shard(variable)
535+
if sv_builder.has_all_shards():
536+
self._variables.append(sv_builder.build())
537+
else:
538+
self._variables.append(variable)
520539
return variable
521540

541+
def _trackable_children(self, save_type="checkpoint", **kwargs):
542+
"""Override in order to coalesce and track `ShardedVariable`s.
543+
544+
If an optimizer variable's corresponding model variable is a shard of a
545+
larger `ShardedVariable`, then we track the optimizer variable in
546+
`self._variables` as a `ShardedVariable` via the logic in
547+
`add_variable_from_reference`. However, most optimizer implementations
548+
additionally keep their variables as attributes, which will be tracked
549+
via `AutoTrackable` functionality and not accumulated into
550+
`ShardedVariable`s.
551+
552+
So, to enable restoration of these attributes in possibly different
553+
sharding configurations, we should save them as `ShardedVariable`s.
554+
Here, any optimizer attributes that are variable shards of a larger
555+
`ShardedVariable` are here replaced by the `ShardedVariable` itself,
556+
which was created in `add_variable_from_reference`.
557+
558+
All non-sharded variables are kept as-is. If none of the model variables
559+
are sharded, this reduces to `AutoTrackable._trackable_children()`.
560+
"""
561+
# Due to object-identity based matching logic in checkpointing, new
562+
# python objects should not be created on each call to
563+
# `_trackable_children`. So instead, only coalesce if not done before.
564+
if not hasattr(self, "_coalesced_children"):
565+
# This new attribute should not be tracked to avoid infinite
566+
# recursion, so wrap in NoDependency
567+
self._coalesced_children = self._no_dependency({})
568+
children = super()._trackable_children(save_type, **kwargs)
569+
for key, val in children.items():
570+
if key not in [
571+
"_variables",
572+
"_index_dict",
573+
"_learning_rate",
574+
"_iterations",
575+
]:
576+
new_val = val
577+
if isinstance(val, list):
578+
# TODO(jmullenbach): handle arbitrary nesting
579+
sv_vals = []
580+
for var in val:
581+
if hasattr(var, "_sharded_container"):
582+
sv = var._sharded_container()
583+
if sv not in sv_vals:
584+
sv_vals.append(sv)
585+
else:
586+
sv_vals.append(var)
587+
new_val = tf.__internal__.tracking.wrap(sv_vals)
588+
self._coalesced_children[key] = new_val
589+
else:
590+
self._coalesced_children[key] = val
591+
return self._coalesced_children
592+
522593
def minimize(self, loss, var_list, tape=None):
523594
"""Minimize `loss` by updating `var_list`.
524595
@@ -1384,6 +1455,30 @@ def __call__(self):
13841455
return self
13851456

13861457

1458+
class _ShardedVariableBuilder:
1459+
"""Accumulate variable shards into a `ShardedVariable`."""
1460+
1461+
def __init__(self, num_shards):
1462+
self.shards = [None] * num_shards
1463+
1464+
def add_shard(self, shard):
1465+
# Get shard index from name
1466+
shard_idx = int(shard.name.split("part_")[-1].split(":")[0])
1467+
if self.shards[shard_idx] is None:
1468+
self.shards[shard_idx] = shard
1469+
else:
1470+
raise ValueError(
1471+
"Cannot add duplicate optimizer variable from "
1472+
f"shard variable {shard.name}"
1473+
)
1474+
1475+
def has_all_shards(self):
1476+
return all([shard is not None for shard in self.shards])
1477+
1478+
def build(self):
1479+
return tf.__internal__.distribute.ShardedVariable(self.shards)
1480+
1481+
13871482
# Register the optimizer for loading from saved_model purpose.
13881483
# When `keras_2` is installed in same env, it raises assertion for duplicate
13891484
# registration with same name. Rename the symbol in this case.

tf_keras/optimizers/optimizer_pss_test.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Tests for calling optimizer on ParameterServerStrategy."""
22

3+
import os
4+
35
import tensorflow.compat.v2 as tf
46
from absl.testing import parameterized
57

@@ -96,7 +98,11 @@ def _verify_accumulators_updated(self, optimizer):
9698
if "iteration" not in var.name and "learning_rate" not in var.name:
9799
# Find a variable not iteration or learning_rate, and verify its
98100
# value is updated (not 0).
99-
self.assertNotAllEqual(var, 0)
101+
if isinstance(var, tf.__internal__.distribute.ShardedVariable):
102+
for shard in var.variables:
103+
self.assertNotAllEqual(shard, 0)
104+
else:
105+
self.assertNotAllEqual(var, 0)
100106

101107
@ds_combinations.generate(
102108
tf.__internal__.test.combinations.combine(
@@ -160,6 +166,112 @@ def replica_fn(data):
160166
self.assertEqual(self.evaluate(optimizer.iterations), 3)
161167
self._verify_accumulators_updated(optimizer)
162168

169+
@ds_combinations.generate(
170+
tf.__internal__.test.combinations.combine(
171+
strategy=STRATEGIES,
172+
shard_config=[
173+
[2, 2],
174+
[2, 3],
175+
[3, 2],
176+
[2, 1],
177+
[1, 1],
178+
[1, 2],
179+
[1, 3],
180+
],
181+
)
182+
)
183+
def testCheckpointShardedVariable(self, strategy, shard_config):
184+
# Data are embedding indices near shard boundaries for 2 or 3 shards
185+
test_indices = [33, 34, 49, 50, 66, 67]
186+
187+
def dataset_fn(_):
188+
x, y = [[index] for index in test_indices], [1, 1, 1, 0, 0, 0]
189+
ds = tf.data.Dataset.from_tensor_slices((x, y))
190+
ds = ds.repeat().batch(6)
191+
return ds
192+
193+
vocab_size = 100
194+
embed_dim = 32
195+
196+
def get_model():
197+
return keras.Sequential(
198+
[
199+
keras.layers.Embedding(vocab_size, embed_dim),
200+
keras.layers.Dense(1, activation="sigmoid"),
201+
]
202+
)
203+
204+
# Override partitioning
205+
if shard_config[0] == 1:
206+
strategy._extended._variable_partitioner = None
207+
else:
208+
strategy._extended._variable_partitioner = (
209+
tf.distribute.experimental.partitioners.FixedShardsPartitioner(
210+
shard_config[0]
211+
)
212+
)
213+
214+
# Create model and optimizer
215+
with strategy.scope():
216+
model = get_model()
217+
optimizer = adam.Adam(0.002)
218+
219+
model.compile(loss="mse", optimizer=optimizer)
220+
221+
model.build(input_shape=(None, 1))
222+
model.optimizer.build(model.trainable_variables)
223+
224+
ds = dataset_creator.DatasetCreator(dataset_fn)
225+
# Train a bit to update optimizer variables
226+
model.fit(ds, epochs=1, steps_per_epoch=5)
227+
228+
self._verify_accumulators_updated(optimizer)
229+
230+
# Extract optimizer variables to later check they restore properly
231+
pre_ckpt_optimizer_values = []
232+
for var in model.optimizer.variables:
233+
# Just check the embedding variables
234+
if var.shape == [vocab_size, embed_dim]:
235+
for index in test_indices:
236+
pre_ckpt_optimizer_values.append(var[index])
237+
# Adam has 2 slot variables, momentum and velocity
238+
self.assertLen(pre_ckpt_optimizer_values, 2 * len(test_indices))
239+
240+
checkpoint_path = os.path.join(self.get_temp_dir(), "model_weights")
241+
model.save_weights(checkpoint_path)
242+
243+
# Create new model under different sharding and load checkpoint
244+
if shard_config[1] == 1:
245+
strategy._extended._variable_partitioner = None
246+
else:
247+
strategy._extended._variable_partitioner = (
248+
tf.distribute.experimental.partitioners.FixedShardsPartitioner(
249+
shard_config[1]
250+
)
251+
)
252+
with strategy.scope():
253+
model_2 = get_model()
254+
optimizer_2 = adam.Adam(0.002)
255+
model_2.compile(loss="mse", optimizer=optimizer_2)
256+
model_2.build(input_shape=(None, 1))
257+
model_2.optimizer.build(model_2.trainable_variables)
258+
model_2.load_weights(checkpoint_path)
259+
260+
post_ckpt_optimizer_values = []
261+
for var in model_2.optimizer.variables:
262+
if var.shape == [vocab_size, embed_dim]:
263+
for index in test_indices:
264+
post_ckpt_optimizer_values.append(var[index])
265+
self.assertLen(post_ckpt_optimizer_values, 2 * len(test_indices))
266+
for pre_val, post_val in zip(
267+
pre_ckpt_optimizer_values, post_ckpt_optimizer_values
268+
):
269+
self.assertAllEqual(pre_val, post_val)
270+
271+
# Confirm training still functional
272+
ds = dataset_creator.DatasetCreator(dataset_fn)
273+
model_2.fit(ds, epochs=1, steps_per_epoch=5)
274+
163275

164276
if __name__ == "__main__":
165277
tf.__internal__.distribute.multi_process_runner.test_main()

0 commit comments

Comments
 (0)