1313import random
1414import unittest
1515from typing import cast , List , Optional , Tuple
16+ from unittest .mock import Mock , patch
1617
1718import torch
1819import torch .distributed as dist
20+ from fbgemm_gpu .split_table_batched_embeddings_ops_training import SparseType
1921from hypothesis import given , settings , strategies as st , Verbosity
2022from torchrec .distributed .embedding_sharding import bucketize_kjt_before_all2all
2123from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
3335 ShardMetadata ,
3436)
3537from torchrec .distributed .utils import (
38+ _quantize_embedding_modules ,
3639 add_params_from_parameter_sharding ,
3740 convert_to_fbgemm_types ,
3841 get_bucket_metadata_from_shard_metadata ,
@@ -79,6 +82,7 @@ def test_get_unsharded_module_names(self) -> None:
7982 dense_device = device ,
8083 sparse_device = device ,
8184 )
85+
8286 dmp = DistributedModelParallel (
8387 module = m ,
8488 init_data_parallel = False ,
@@ -95,6 +99,229 @@ def test_get_unsharded_module_names(self) -> None:
9599 dist .destroy_process_group ()
96100
97101
class QuantizeEmbeddingModulesTest(unittest.TestCase):
    """Unit tests for ``_quantize_embedding_modules``.

    All embedding modules and weight tensors are fully mocked (mimicking
    ``SplitTableBatchedEmbeddingBagsCodegen``), so the tests run without a GPU
    and without real fbgemm kernels. The contract verified here is that the
    function, for every sharded embedding module:
      * converts ``weights_dev`` / ``weights_host`` / ``weights_uvm`` via
        ``.to(<converted dtype>)``,
      * frees the old storage via ``untyped_storage().resize_(0)``,
      * updates ``weights_precision`` to the converted sparse type.
    """

    @staticmethod
    def _make_mock_weights(converted_dtype: torch.dtype) -> Tuple[Mock, Mock]:
        """Build one mock FP32 weight tensor.

        ``.to()`` returns a mock whose ``dtype`` is ``converted_dtype``, and
        ``untyped_storage()`` returns a mock storage that records ``resize_``
        calls. Returns ``(weights_mock, storage_mock)``.
        """
        weights = Mock()
        weights.dtype = torch.float32
        weights.to.return_value = Mock()
        weights.to.return_value.dtype = converted_dtype
        storage = Mock()
        storage.resize_ = Mock()
        weights.untyped_storage.return_value = storage
        return weights, storage

    @classmethod
    def _make_mock_emb(
        cls, converted_dtype: torch.dtype
    ) -> Tuple[Mock, List[Tuple[Mock, Mock]]]:
        """Build a mock embedding module with dev/host/uvm weight tensors.

        Returns ``(emb_mock, [(weights, storage), ...])`` where the list holds
        the (tensor, storage) pair for weights_dev, weights_host, weights_uvm
        in that order, for later assertion.
        """
        emb = Mock()
        pairs = [cls._make_mock_weights(converted_dtype) for _ in range(3)]
        emb.weights_dev = pairs[0][0]
        emb.weights_host = pairs[1][0]
        emb.weights_uvm = pairs[2][0]
        emb.weights_precision = SparseType.FP32
        return emb, pairs

    def test_quantize_embedding_modules(self) -> None:
        """Test that _quantize_embedding_modules correctly converts embedding weight tensors."""
        mock_emb, weight_pairs = self._make_mock_emb(torch.float16)
        module = torch.nn.Module()

        # Patch the helpers the function reaches for, so only the conversion
        # logic itself is under test.
        with patch(
            "torchrec.distributed.utils._group_sharded_modules"
        ) as mock_group_sharded, patch(
            "torchrec.distributed.utils.data_type_to_sparse_type"
        ) as mock_convert, patch(
            "torchrec.distributed.utils.logger"
        ) as mock_logger:
            mock_group_sharded.return_value = [mock_emb]
            mock_sparse_type = Mock()
            mock_sparse_type.as_dtype.return_value = torch.float16
            mock_convert.return_value = mock_sparse_type

            _quantize_embedding_modules(module, DataType.FP16)

            # The module tree was scanned and the dtype mapping consulted once.
            mock_group_sharded.assert_called_once_with(module)
            mock_convert.assert_called_once_with(DataType.FP16)

            mock_logger.info.assert_called_once_with(
                f"convert embedding modules to converted_dtype={DataType.FP16.value} quantization"
            )

            # Each of dev/host/uvm was converted and its old storage freed.
            for weights, storage in weight_pairs:
                weights.to.assert_called_once_with(torch.float16)
                storage.resize_.assert_called_once_with(0)

            self.assertEqual(mock_emb.weights_precision, mock_sparse_type)

    def test_quantize_embedding_modules_no_sharded_modules(self) -> None:
        """Test that _quantize_embedding_modules handles modules with no sharded embeddings."""
        module = torch.nn.Module()

        with patch(
            "torchrec.distributed.utils._group_sharded_modules"
        ) as mock_group_sharded, patch(
            "torchrec.distributed.utils.data_type_to_sparse_type"
        ) as mock_convert, patch(
            "torchrec.distributed.utils.logger"
        ) as mock_logger:
            mock_group_sharded.return_value = []
            mock_convert.return_value = Mock()

            # With no sharded embeddings the call must be a harmless no-op.
            _quantize_embedding_modules(module, DataType.FP16)

            mock_group_sharded.assert_called_once_with(module)
            mock_convert.assert_called_once_with(DataType.FP16)
            mock_logger.info.assert_called_once()

    def test_quantize_embedding_modules_multiple_embeddings(self) -> None:
        """Test that _quantize_embedding_modules handles multiple embedding modules."""
        mock_emb1, pairs1 = self._make_mock_emb(torch.int8)
        mock_emb2, pairs2 = self._make_mock_emb(torch.int8)
        module = torch.nn.Module()

        with patch(
            "torchrec.distributed.utils._group_sharded_modules"
        ) as mock_group_sharded, patch(
            "torchrec.distributed.utils.data_type_to_sparse_type"
        ) as mock_convert:
            mock_group_sharded.return_value = [mock_emb1, mock_emb2]
            mock_sparse_type = Mock()
            mock_sparse_type.as_dtype.return_value = torch.int8
            mock_convert.return_value = mock_sparse_type

            _quantize_embedding_modules(module, DataType.INT8)

            # Every embedding returned by _group_sharded_modules was
            # converted — not just the first.
            for weights, storage in pairs1 + pairs2:
                weights.to.assert_called_once_with(torch.int8)
                storage.resize_.assert_called_once_with(0)

            self.assertEqual(mock_emb1.weights_precision, mock_sparse_type)
            self.assertEqual(mock_emb2.weights_precision, mock_sparse_type)
323+
324+
98325def _compute_translated_lengths (
99326 row_indices : List [int ],
100327 indices_offsets : List [int ],
0 commit comments