11import collections
2+ import dataclasses
23import importlib .util
34import typing
45from typing import Any , Sequence
2021SUPPORTED_PLACEMENTS = ("auto" , "default_device" , "sparsecore" )
2122
2223
23- PlacementAndPath = collections .namedtuple (
24- "PlacementAndPath" , ["placement" , "path" ]
25- )
24+ @dataclasses .dataclass (eq = True , unsafe_hash = True , order = True )
25+ class PlacementAndPath :
26+ placement : str
27+ path : str
2628
2729
2830def _ragged_to_dense_inputs (
@@ -518,12 +520,12 @@ def _init_feature_configs_structures(
518520 With these structures in place, the steps to:
519521 - go from the deeply nested structure to the two-level structure are:
520522 - `assert_same_struct` as `self._feature_configs`
521- - `flatten`
522- - `pack_sequence_as` `self._placement_to_path_to_feature_config`
523+ - use `self._feature_deeply_nested_placement_and_paths` to map from
524+ deeply nested to two-level
523525 - go from the two-level structure to the deeply nested structure:
524- - `assert_same_struct` as `self._placement_to_path_to_feature_config`
525- - `flatten`
526- - `pack_sequence_as` `self._feature_configs`
526+ - `assert_same_struct` as `self._placement_to_path_to_feature_config`
527+ - use `self._feature_deeply_nested_placement_and_paths` to locate each
528+ output in the two-level dicts
527529
528530 Args:
529531 feature_configs: The deeply nested structure of `FeatureConfig` or
@@ -590,14 +592,14 @@ def build(self, input_shapes: types.Nested[types.Shape]) -> None:
590592 ] = collections .defaultdict (dict )
591593
592594 def populate_placement_to_path_to_input_shape (
593- placement_and_path : PlacementAndPath , input_shape : types .Shape
595+ pp : PlacementAndPath , input_shape : types .Shape
594596 ) -> None :
595- placement_to_path_to_input_shape [placement_and_path .placement ][
596- placement_and_path . path
597- ] = input_shape
597+ placement_to_path_to_input_shape [pp .placement ][pp . path ] = (
598+ input_shape
599+ )
598600
599601 keras .tree .map_structure_up_to (
600- self ._feature_configs ,
602+ self ._feature_deeply_nested_placement_and_paths ,
601603 populate_placement_to_path_to_input_shape ,
602604 self ._feature_deeply_nested_placement_and_paths ,
603605 input_shapes ,
@@ -645,35 +647,40 @@ def preprocess(
645647 """
646648 # Verify input structure.
647649 keras .tree .assert_same_structure (self ._feature_configs , inputs )
650+ if weights is not None :
651+ keras .tree .assert_same_structure (self ._feature_configs , weights )
648652
649653 if not self .built :
650- input_shapes = keras .tree .map_structure_up_to (
651- self ._feature_configs ,
654+ input_shapes = keras .tree .map_structure (
652655 lambda array : backend .standardize_shape (array .shape ),
653656 inputs ,
654657 )
655658 self .build (input_shapes )
656659
657- # Go from deeply nested structure of inputs to flat inputs.
658- flat_inputs = keras .tree .flatten (inputs )
660+ # Go from deeply nested to nested dict placement -> path -> input.
661+ def to_placement_to_path (
662+ tensors : types .Nested [types .Tensor ],
663+ ) -> dict [str , dict [str , types .Tensor ]]:
664+ result : dict [str , dict [str , types .Tensor ]] = {
665+ p : dict () for p in self ._placement_to_path_to_feature_config
666+ }
659667
660- # Go from flat to nested dict placement -> path -> input.
661- placement_to_path_to_inputs = keras .tree .pack_sequence_as (
662- self ._placement_to_path_to_feature_config , flat_inputs
663- )
668+ def populate (pp : PlacementAndPath , x : types .Tensor ) -> None :
669+ result [pp .placement ][pp .path ] = x
664670
665- if weights is not None :
666- # Same for weights if present.
667- keras .tree .assert_same_structure (self ._feature_configs , weights )
668- flat_weights = keras .tree .flatten (weights )
669- placement_to_path_to_weights = keras .tree .pack_sequence_as (
670- self ._placement_to_path_to_feature_config , flat_weights
671+ keras .tree .map_structure (
672+ populate ,
673+ self ._feature_deeply_nested_placement_and_paths ,
674+ tensors ,
671675 )
672- else :
673- # Populate keys for weights.
674- placement_to_path_to_weights = {
675- k : None for k in placement_to_path_to_inputs
676- }
676+ return result
677+
678+ placement_to_path_to_inputs = to_placement_to_path (inputs )
679+
680+ # Same for weights if present.
681+ placement_to_path_to_weights = (
682+ to_placement_to_path (weights ) if weights is not None else None
683+ )
677684
678685 placement_to_path_to_preprocessed : dict [
679686 str , dict [str , dict [str , types .Nested [types .Tensor ]]]
@@ -684,7 +691,9 @@ def preprocess(
684691 placement_to_path_to_preprocessed ["sparsecore" ] = (
685692 self ._sparsecore_preprocess (
686693 placement_to_path_to_inputs ["sparsecore" ],
687- placement_to_path_to_weights ["sparsecore" ],
694+ placement_to_path_to_weights ["sparsecore" ]
695+ if placement_to_path_to_weights is not None
696+ else None ,
688697 training ,
689698 )
690699 )
@@ -694,7 +703,9 @@ def preprocess(
694703 placement_to_path_to_preprocessed ["default_device" ] = (
695704 self ._default_device_preprocess (
696705 placement_to_path_to_inputs ["default_device" ],
697- placement_to_path_to_weights ["default_device" ],
706+ placement_to_path_to_weights ["default_device" ]
707+ if placement_to_path_to_weights is not None
708+ else None ,
698709 training ,
699710 )
700711 )
@@ -780,11 +791,13 @@ def call(
780791 placement_to_path_to_outputs ,
781792 )
782793
783- # Go from placement -> path -> output to flat outputs.
784- flat_outputs = keras .tree .flatten (placement_to_path_to_outputs )
794+ # Go from placement -> path -> output to deeply nested structure.
795+ def populate_output (pp : PlacementAndPath ) -> types .Tensor :
796+ return placement_to_path_to_outputs [pp .placement ][pp .path ]
785797
786- # Go from flat outputs to deeply nested structure.
787- return keras .tree .pack_sequence_as (self ._feature_configs , flat_outputs )
798+ return keras .tree .map_structure (
799+ populate_output , self ._feature_deeply_nested_placement_and_paths
800+ )
788801
789802 def get_embedding_tables (self ) -> dict [str , types .Tensor ]:
790803 """Return the content of the embedding tables by table name.
0 commit comments