Commit 41df655

Test DTensor bf16 stochastic round parity

1 parent: 9a6c9c3
File tree: 1 file changed

test/test_low_bit_optim.py (+39, -0 lines)
@@ -119,6 +119,45 @@ def test_bf16_stochastic_round(self, device, compile):
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
         torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)
 
+    @parametrize("device", _DEVICES)
+    @parametrize("compile", [False, True])
+    def test_bf16_stochastic_round_dtensor_cpu(self, device, compile):
+        pytest.importorskip("torch.distributed")
+        import torch.distributed as dist
+        from torch.distributed.tensor import DTensor, Replicate
+        from torch.distributed.device_mesh import init_device_mesh
+
+        created_pg = False
+        if dist.is_available() and not dist.is_initialized():
+            store = dist.TCPStore("127.0.0.1", 29500, 1, True)
+            dist.init_process_group(
+                backend="gloo",
+                store=store,
+                rank=0,
+                world_size=1,
+            )
+            created_pg = True
+
+        try:
+            torch.manual_seed(common_utils.SEED)
+            x = torch.rand(32, device=device) * 100
+            x_rep = x.view(-1, 1).repeat(1, 100_000)
+
+            func = torch.compile(
+                _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+            )
+            out_plain = func(x_rep)
+
+            mesh = init_device_mesh(device, (1,))
+            x_dt = DTensor.from_local(x_rep, mesh, [Replicate()], run_check=False)
+            out_dt = func(x_dt)
+
+            assert isinstance(out_dt, DTensor)
+            torch.testing.assert_close(out_dt.to_local(), out_plain)
+        finally:
+            if created_pg:
+                dist.destroy_process_group()
+
 
 class TestOptim(TestCase):
     @parametrize(
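For reference, a minimal standalone sketch of the parity check this test performs, outside the pytest harness. It assumes a single-process gloo group on CPU and uses torch.sin as a stand-in for _fp32_to_bf16_sr (whose import lies outside this hunk); the point is only that an op applied to a replicated DTensor should return a DTensor whose local shard matches the plain-tensor result.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Single-process "distributed" setup so DTensor can be used on one CPU rank.
store = dist.TCPStore("127.0.0.1", 29501, 1, True)
dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
try:
    x = torch.rand(32, 1000) * 100

    # Reference: run the op on a plain tensor.
    out_plain = torch.sin(x)

    # Same op on a DTensor replicated over a 1-device mesh.
    mesh = init_device_mesh("cpu", (1,))
    x_dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
    out_dt = torch.sin(x_dt)

    # Parity: the local shard of the DTensor result matches the plain result.
    assert isinstance(out_dt, DTensor)
    torch.testing.assert_close(out_dt.to_local(), out_plain)
finally:
    dist.destroy_process_group()
```

Assuming the repo's usual pytest flow, the new case should be selectable with something like `pytest test/test_low_bit_optim.py -k bf16_stochastic_round_dtensor`.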

0 commit comments

Comments
 (0)