1010import copy
1111import enum
1212import unittest
13+ from typing import Tuple , Union
1314from unittest .mock import MagicMock
1415
1516import torch
1617
1718from torchrec .distributed .embedding_types import EmbeddingComputeKernel
18- from torchrec .distributed .test_utils .test_model import ModelInput , TestNegSamplingModule
19+ from torchrec .distributed .test_utils .test_model import (
20+ ModelInput ,
21+ TestNegSamplingModule ,
22+ TestSparseNN ,
23+ )
1924from torchrec .distributed .train_pipeline .pipeline_context import TrainPipelineContext
2025from torchrec .distributed .train_pipeline .runtime_forwards import PipelinedForward
21-
2226from torchrec .distributed .train_pipeline .tests .test_train_pipelines_base import (
2327 TrainPipelineSparseDistTestBase ,
2428)
25- from torchrec .distributed .train_pipeline .tracing import (
26- ArgInfo ,
27- ArgInfoStepFactory ,
28- CallArgs ,
29- NodeArgsHelper ,
30- PipelinedPostproc ,
31- )
29+ from torchrec .distributed .train_pipeline .tracing import CallArgs , PipelinedPostproc
3230from torchrec .distributed .train_pipeline .utils import _rewrite_model
3331from torchrec .distributed .types import ShardingType
3432from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
@@ -40,6 +38,15 @@ class ModelType(enum.Enum):
4038 PIPELINED = "pipelined"
4139
4240
@torch.fx.wrap
def enrich_hstu_features(
    kjt: KeyedJaggedTensor, hstu_factor: float
) -> KeyedJaggedTensor:
    """Scale the weights of ``kjt`` in place by ``hstu_factor``.

    The function is registered with ``torch.fx.wrap``, so symbolic tracing
    records a call to it as a single node instead of tracing into its body.

    Args:
        kjt: jagged tensor whose ``_weights`` are scaled; the object is
            mutated, not copied.
        hstu_factor: multiplier applied to the weights.

    Returns:
        The very same ``kjt`` object that was passed in. It is returned
        untouched when it carries no weights.
    """
    # Nothing to scale for an unweighted KJT.
    if kjt._weights is None:
        return kjt
    kjt._weights *= hstu_factor
    return kjt
48+
49+
4350class TrainPipelineUtilsTest (TrainPipelineSparseDistTestBase ):
4451 # pyre-fixme[56]: Pyre was not able to infer the type of argument
4552 @unittest .skipIf (
@@ -257,3 +264,115 @@ def test_restore_from_snapshot(self) -> None:
257264 ]
258265 for source_model_type , recipient_model_type in variants :
259266 self ._test_restore_from_snapshot (source_model_type , recipient_model_type )
267+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(
    not torch.cuda.is_available(),
    "Not enough GPUs, this test requires at least one GPU",
)
def test_rewrite_model_with_fx_wrap(self) -> None:
    """Check ``_rewrite_model`` against an ``fx.wrap``-ed feature transform.

    The same sharded model is rewritten twice with ``pipeline_postproc=True``:

    1. ``forward`` calls ``enrich_hstu_features`` directly; afterwards neither
       ``ebc.forward`` nor ``weighted_ebc.forward`` is a ``PipelinedForward``.
    2. ``forward`` routes the input through ``postproc_module`` (which wraps
       the same function); afterwards both forwards are ``PipelinedForward``
       and the first step of their first call argument refers to the model's
       ``postproc_module``.

    Finally the sharded model's ``state_dict`` is loaded back into itself and
    must report no missing or unexpected keys.
    """

    class TestPostProcModule(torch.nn.Module):
        # Thin module wrapper around ``enrich_hstu_features``.
        def __init__(self, f: float):
            super().__init__()
            self.f = f

        def forward(self, x: KeyedJaggedTensor) -> KeyedJaggedTensor:
            return enrich_hstu_features(x, self.f)

    class TestModel(TestSparseNN):
        # Class-level switch read through ``type(self)`` in ``forward``;
        # flipped by the test between the two rewrite attempts.
        use_postproc_module: bool = False

        def forward(
            self,
            input: ModelInput,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            if type(self).use_postproc_module:
                enriched = self.postproc_module(input)
            else:
                enriched = enrich_hstu_features(input, 0.3)
            sparse_out = self.sparse_forward(enriched)
            return self.dense_forward(enriched, sparse_out)

    model = TestModel(
        tables=self.tables,
        weighted_tables=self.weighted_tables,
        dense_device=self.device,
        sparse_device=torch.device("meta"),
        postproc_module=TestPostProcModule(0.3),
    )

    fused_params = {}
    sharded_model, optim = self._generate_sharded_model_and_optimizer(
        model,
        ShardingType.TABLE_WISE.value,
        EmbeddingComputeKernel.FUSED.value,
        fused_params,
    )

    def rewrite_with_postproc_pipelining() -> None:
        # Each attempt gets a fresh context and runs without a batch or a
        # dist stream.
        _rewrite_model(
            model=sharded_model,
            batch=None,
            context=TrainPipelineContext(),
            dist_stream=None,
            pipeline_postproc=True,
        )

    ebc_names = ("ebc", "weighted_ebc")

    # Attempt 1: the postproc is a plain function call.
    # EBC forwards not overwritten to PipelinedForward due to KJT modification
    self.assertFalse(model.use_postproc_module)
    rewrite_with_postproc_pipelining()
    for ebc_name in ebc_names:
        self.assertNotIsInstance(
            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
            #  `sparse`.
            getattr(sharded_model.module.sparse, ebc_name).forward,
            PipelinedForward,
        )

    # Attempt 2: the postproc goes through the module.
    TestModel.use_postproc_module = True
    self.assertTrue(model.use_postproc_module)
    rewrite_with_postproc_pipelining()

    for ebc_name in ebc_names:
        self.assertIsInstance(
            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
            #  `sparse`.
            getattr(sharded_model.module.sparse, ebc_name).forward,
            PipelinedForward,
        )
    for ebc_name in ebc_names:
        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
        #  `sparse`.
        pipelined_forward = getattr(sharded_model.module.sparse, ebc_name).forward
        first_call_arg = pipelined_forward._args.args[0]
        self.assertEqual(
            first_call_arg.steps[0].postproc_module,
            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
            #  `postproc_module`.
            sharded_model.module.postproc_module,
        )

    # The rewritten model must still round-trip its own state dict.
    snapshot = sharded_model.state_dict()
    missing_keys, unexpected_keys = sharded_model.load_state_dict(snapshot)
    self.assertEqual(missing_keys, [])
    self.assertEqual(unexpected_keys, [])
0 commit comments