@@ -119,6 +119,45 @@ def test_bf16_stochastic_round(self, device, compile):
119119 # must cast BF16 tensor back to FP32 so that .mean() is accurate
120120 torch .testing .assert_close (x_rep_bf16 .float ().mean (1 ), x , atol = 3e-5 , rtol = 3e-5 )
121121
122+ @parametrize ("device" , _DEVICES )
123+ @parametrize ("compile" , [False , True ])
124+ def test_bf16_stochastic_round_dtensor_cpu (self , device , compile ):
125+ pytest .importorskip ("torch.distributed" )
126+ import torch .distributed as dist
127+ from torch .distributed .tensor import DTensor , Replicate
128+ from torch .distributed .device_mesh import init_device_mesh
129+
130+ created_pg = False
131+ if dist .is_available () and not dist .is_initialized ():
132+ store = dist .TCPStore ("127.0.0.1" , 29500 , 1 , True )
133+ dist .init_process_group (
134+ backend = "gloo" ,
135+ store = store ,
136+ rank = 0 ,
137+ world_size = 1 ,
138+ )
139+ created_pg = True
140+
141+ try :
142+ torch .manual_seed (common_utils .SEED )
143+ x = torch .rand (32 , device = device ) * 100
144+ x_rep = x .view (- 1 , 1 ).repeat (1 , 100_000 )
145+
146+ func = torch .compile (
147+ _fp32_to_bf16_sr , fullgraph = True , dynamic = False , disable = not compile
148+ )
149+ out_plain = func (x_rep )
150+
151+ mesh = init_device_mesh (device , (1 ,))
152+ x_dt = DTensor .from_local (x_rep , mesh , [Replicate ()], run_check = False )
153+ out_dt = func (x_dt )
154+
155+ assert isinstance (out_dt , DTensor )
156+ torch .testing .assert_close (out_dt .to_local (), out_plain )
157+ finally :
158+ if created_pg :
159+ dist .destroy_process_group ()
160+
122161
123162class TestOptim (TestCase ):
124163 @parametrize (
0 commit comments