
Commit 7c36879
Fix handling of mixed placement in DistributedEmbedding. (#136)
The transformation from the deeply nested input structure to the intermediate representation of "placement -> path -> tensor", and the reverse transformation for the outputs, were incorrectly based on flattening and then packing as `self._placement_to_path_to_feature_config`. This used an incorrect order for the flat structure and resulted in tensors ending up in the wrong place. This fix uses `self._feature_deeply_nested_placement_and_paths` instead, which maps each feature correctly.

Also changed `PlacementAndPath` to be a `dataclass` so that `keras.tree` treats its instances as atoms and doesn't recurse into them. This simplifies some of the `map_structure{_up_to}` calls.

Fixes #134
1 parent e002b4b commit 7c36879
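
For context, a minimal sketch of the ordering problem the commit message describes. This is illustrative code, not part of the repository; the feature names and placements mirror the new `test_mixed_placement` test, and plain strings stand in for tensors and `FeatureConfig` objects.

import keras

# Deeply nested (user-facing) inputs; strings stand in for tensors.
# feature2 is placed on "sparsecore", the other two on "default_device",
# as in the new test_mixed_placement test.
inputs = {"feature1": "x1", "feature2": "x2", "feature3": "x3"}

# Intermediate two-level structure: placement -> path -> feature config
# (strings stand in for FeatureConfig objects).
placement_to_path_to_feature_config = {
    "default_device": {"feature1": "cfg1", "feature3": "cfg3"},
    "sparsecore": {"feature2": "cfg2"},
}

# Old approach: flatten the deeply nested inputs, then pack the flat list
# into the two-level structure. The two structures flatten in different
# orders, so tensors land under the wrong feature.
flat_inputs = keras.tree.flatten(inputs)  # ["x1", "x2", "x3"]
repacked = keras.tree.pack_sequence_as(
    placement_to_path_to_feature_config, flat_inputs
)
# With sorted-key flattening of the two-level structure, the slots are
# filled in the order feature1, feature3, feature2, so:
#   repacked["default_device"]["feature3"] == "x2"  (should be "x3")
#   repacked["sparsecore"]["feature2"] == "x3"      (should be "x2")
print(repacked)

Mapping through `self._feature_deeply_nested_placement_and_paths` pairs every input tensor with its own `PlacementAndPath`, so it no longer matters that the two structures flatten in different orders.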

File tree

2 files changed: +122 -38 lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 51 additions & 38 deletions
@@ -1,4 +1,5 @@
 import collections
+import dataclasses
 import importlib.util
 import typing
 from typing import Any, Sequence
@@ -20,9 +21,10 @@
 SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")


-PlacementAndPath = collections.namedtuple(
-    "PlacementAndPath", ["placement", "path"]
-)
+@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+class PlacementAndPath:
+    placement: str
+    path: str


 def _ragged_to_dense_inputs(
@@ -518,12 +520,12 @@ def _init_feature_configs_structures(
         With these structures in place, the steps to:
         - go from the deeply nested structure to the two-level structure are:
           - `assert_same_struct` as `self._feature_configs`
-          - `flatten`
-          - `pack_sequence_as` `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to map from
+            deeply nested to two-level
         - go from the two-level structure to the deeply nested structure:
-          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
-          - `flatten`
-          - `pack_sequence_as` `self._feature_configs`
+          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to locate each
+            output in the two-level dicts

         Args:
             feature_configs: The deeply nested structure of `FeatureConfig` or
@@ -590,14 +592,14 @@ def build(self, input_shapes: types.Nested[types.Shape]) -> None:
         ] = collections.defaultdict(dict)

         def populate_placement_to_path_to_input_shape(
-            placement_and_path: PlacementAndPath, input_shape: types.Shape
+            pp: PlacementAndPath, input_shape: types.Shape
         ) -> None:
-            placement_to_path_to_input_shape[placement_and_path.placement][
-                placement_and_path.path
-            ] = input_shape
+            placement_to_path_to_input_shape[pp.placement][pp.path] = (
+                input_shape
+            )

         keras.tree.map_structure_up_to(
-            self._feature_configs,
+            self._feature_deeply_nested_placement_and_paths,
             populate_placement_to_path_to_input_shape,
             self._feature_deeply_nested_placement_and_paths,
             input_shapes,
@@ -645,35 +647,40 @@ def preprocess(
         """
         # Verify input structure.
         keras.tree.assert_same_structure(self._feature_configs, inputs)
+        if weights is not None:
+            keras.tree.assert_same_structure(self._feature_configs, weights)

         if not self.built:
-            input_shapes = keras.tree.map_structure_up_to(
-                self._feature_configs,
+            input_shapes = keras.tree.map_structure(
                 lambda array: backend.standardize_shape(array.shape),
                 inputs,
             )
             self.build(input_shapes)

-        # Go from deeply nested structure of inputs to flat inputs.
-        flat_inputs = keras.tree.flatten(inputs)
+        # Go from deeply nested to nested dict placement -> path -> input.
+        def to_placement_to_path(
+            tensors: types.Nested[types.Tensor],
+        ) -> dict[str, dict[str, types.Tensor]]:
+            result: dict[str, dict[str, types.Tensor]] = {
+                p: dict() for p in self._placement_to_path_to_feature_config
+            }

-        # Go from flat to nested dict placement -> path -> input.
-        placement_to_path_to_inputs = keras.tree.pack_sequence_as(
-            self._placement_to_path_to_feature_config, flat_inputs
-        )
+            def populate(pp: PlacementAndPath, x: types.Tensor) -> None:
+                result[pp.placement][pp.path] = x

-        if weights is not None:
-            # Same for weights if present.
-            keras.tree.assert_same_structure(self._feature_configs, weights)
-            flat_weights = keras.tree.flatten(weights)
-            placement_to_path_to_weights = keras.tree.pack_sequence_as(
-                self._placement_to_path_to_feature_config, flat_weights
+            keras.tree.map_structure(
+                populate,
+                self._feature_deeply_nested_placement_and_paths,
+                tensors,
             )
-        else:
-            # Populate keys for weights.
-            placement_to_path_to_weights = {
-                k: None for k in placement_to_path_to_inputs
-            }
+            return result
+
+        placement_to_path_to_inputs = to_placement_to_path(inputs)
+
+        # Same for weights if present.
+        placement_to_path_to_weights = (
+            to_placement_to_path(weights) if weights is not None else None
+        )

         placement_to_path_to_preprocessed: dict[
@@ -684,7 +691,9 @@ def preprocess(
             placement_to_path_to_preprocessed["sparsecore"] = (
                 self._sparsecore_preprocess(
                     placement_to_path_to_inputs["sparsecore"],
-                    placement_to_path_to_weights["sparsecore"],
+                    placement_to_path_to_weights["sparsecore"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -694,7 +703,9 @@ def preprocess(
             placement_to_path_to_preprocessed["default_device"] = (
                 self._default_device_preprocess(
                     placement_to_path_to_inputs["default_device"],
-                    placement_to_path_to_weights["default_device"],
+                    placement_to_path_to_weights["default_device"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -780,11 +791,13 @@ def call(
             placement_to_path_to_outputs,
         )

-        # Go from placement -> path -> output to flat outputs.
-        flat_outputs = keras.tree.flatten(placement_to_path_to_outputs)
+        # Go from placement -> path -> output to deeply nested structure.
+        def populate_output(pp: PlacementAndPath) -> types.Tensor:
+            return placement_to_path_to_outputs[pp.placement][pp.path]

-        # Go from flat outputs to deeply nested structure.
-        return keras.tree.pack_sequence_as(self._feature_configs, flat_outputs)
+        return keras.tree.map_structure(
+            populate_output, self._feature_deeply_nested_placement_and_paths
+        )

     def get_embedding_tables(self) -> dict[str, types.Tensor]:
         """Return the content of the embedding tables by table name.

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 71 additions & 0 deletions
@@ -685,6 +685,77 @@ def test_shared_table(self):
             res["feature3"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
         )

+    def test_mixed_placement(self):
+        if not self.on_tpu:
+            self.skipTest("Mixed placement test requires a TPU.")
+
+        # Use different embedding dimensions to verify that the correct tables
+        # are used for each feature.
+        embedding_output_dim1 = 16
+        embedding_output_dim2 = 32
+        embedding_output_dim3 = 64
+
+        # Intermix placement to exercise the change of order of inputs.
+        table1 = config.TableConfig(
+            name="table1",
+            vocabulary_size=VOCABULARY_SIZE,
+            embedding_dim=embedding_output_dim1,
+            placement="default_device",
+        )
+        table2 = config.TableConfig(
+            name="table2",
+            vocabulary_size=VOCABULARY_SIZE,
+            embedding_dim=embedding_output_dim2,
+            placement="sparsecore",
+        )
+        table3 = config.TableConfig(
+            name="table3",
+            vocabulary_size=VOCABULARY_SIZE,
+            embedding_dim=embedding_output_dim3,
+            placement="default_device",
+        )
+
+        embedding_config = {
+            "feature1": config.FeatureConfig(
+                name="feature1",
+                table=table1,
+                input_shape=(BATCH_SIZE_PER_CORE, 1),
+                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim1),
+            ),
+            "feature2": config.FeatureConfig(
+                name="feature2",
+                table=table2,
+                input_shape=(BATCH_SIZE_PER_CORE, 1),
+                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim2),
+            ),
+            "feature3": config.FeatureConfig(
+                name="feature3",
+                table=table3,
+                input_shape=(BATCH_SIZE_PER_CORE, 1),
+                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim3),
+            ),
+        }
+
+        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
+        inputs, _, _ = self.create_inputs_weights_and_labels(
+            batch_size, "dense", embedding_config
+        )
+
+        with self._strategy.scope():
+            layer = distributed_embedding.DistributedEmbedding(embedding_config)
+
+        res = self.run_with_strategy(layer.__call__, inputs)
+
+        self.assertEqual(
+            res["feature1"].shape, (batch_size, embedding_output_dim1)
+        )
+        self.assertEqual(
+            res["feature2"].shape, (batch_size, embedding_output_dim2)
+        )
+        self.assertEqual(
+            res["feature3"].shape, (batch_size, embedding_output_dim3)
+        )
+
     def test_save_load_model(self):
         batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config("dense", self.placement)
