39 changes: 39 additions & 0 deletions test/test_low_bit_optim.py
@@ -119,6 +119,45 @@ def test_bf16_stochastic_round(self, device, compile):
# must cast BF16 tensor back to FP32 so that .mean() is accurate
torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)

@parametrize("device", _DEVICES)
@parametrize("compile", [False, True])
def test_bf16_stochastic_round_dtensor(self, device, compile):
pytest.importorskip("torch.distributed")
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.device_mesh import init_device_mesh

created_pg = False
if dist.is_available() and not dist.is_initialized():
store = dist.TCPStore("127.0.0.1", 29500, 1, True)
dist.init_process_group(
backend="gloo",
store=store,
rank=0,
world_size=1,
)
created_pg = True

try:
torch.manual_seed(common_utils.SEED)
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

func = torch.compile(
_fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
)
out_plain = func(x_rep)

mesh = init_device_mesh(device, (1,))
x_dt = DTensor.from_local(x_rep, mesh, [Replicate()], run_check=False)
out_dt = func(x_dt)

assert isinstance(out_dt, DTensor)
torch.testing.assert_close(out_dt.to_local(), out_plain)
finally:
if created_pg:
dist.destroy_process_group()


class TestOptim(TestCase):
@parametrize(
13 changes: 11 additions & 2 deletions torchao/optim/quant_utils.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor
from torch.distributed.tensor import DTensor


# https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +118,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
return out.view(codes.shape)


def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
# For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
# - Round towards zero: [a31, ..., a16, 0, ..., 0]
# - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +128,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
# [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
#
# we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
is_dt = isinstance(_x_f32, DTensor)
x_f32 = _x_f32.to_local() if is_dt else _x_f32

rand_16bit = torch.randint(
0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
)
@@ -142,4 +146,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
)
# alternative, slightly faster
# x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
return x_f32_bits.view(torch.float32).bfloat16()
x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()

return DTensor.from_local(
x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements,
run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()),
) if is_dt else x_bf16_trunc
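
A minimal standalone sketch (not part of the diff) of the probability described in the comments above: the low 16 bits of the FP32 bit pattern, divided by 2^16, give the chance that stochastic rounding moves the value away from zero. The import path follows the file touched here; the concrete value and the empirical check are illustrative assumptions.

import torch

from torchao.optim.quant_utils import _fp32_to_bf16_sr  # module path taken from this diff

# Read the low 16 bits of the FP32 bit pattern of an arbitrary sample value.
x = torch.tensor(1.2345, dtype=torch.float32)
low16 = x.view(torch.int32).item() & 0xFFFF
print(f"expected round-up probability ~ {low16 / 2**16:.4f}")

# Empirical check: the fraction of samples rounded away from zero should match.
x_rep = x.repeat(100_000)
out = _fp32_to_bf16_sr(x_rep)
frac_up = (out.float() > x.item()).float().mean().item()
print(f"observed round-up fraction   ~ {frac_up:.4f}")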
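
The DTensor branch added here follows an unwrap-compute-rewrap pattern: take the local tensor, run the bit-level ops (which expect a plain Tensor), then rebuild a DTensor on the original mesh and placements. Below is a standalone sketch of that pattern under a single-rank gloo group, mirroring the test above; the helper name, the port, and the simple truncation op are illustrative assumptions, not code from this PR.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate


def _truncate_to_bf16_grid(x):
    # Illustrative helper mirroring the pattern used by _fp32_to_bf16_sr:
    # unwrap to the local tensor, do plain-Tensor bit math, rewrap as DTensor.
    is_dt = isinstance(x, DTensor)
    local = x.to_local() if is_dt else x
    # Clear the low 16 bits of the FP32 pattern (round toward zero onto the BF16 grid).
    out = (local.view(torch.int32) & -(1 << 16)).view(torch.float32)
    if is_dt:
        return DTensor.from_local(out, x.device_mesh, x.placements, run_check=False)
    return out


if __name__ == "__main__":
    store = dist.TCPStore("127.0.0.1", 29501, 1, True)  # assumed free port
    dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))
    x = DTensor.from_local(torch.rand(8), mesh, [Replicate()], run_check=False)
    print(_truncate_to_bf16_grid(x).to_local())
    dist.destroy_process_group()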