Skip to content

Add support for bmm and to for fbgemm Tensor #2337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/dtypes/test_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def setUp(self):
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
)
self.bmm_config = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
transpose_input=True,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

def test_linear(self):
dtype = torch.bfloat16
Expand Down Expand Up @@ -106,6 +113,39 @@ def test_slice_and_copy_(self):
# making sure param.data is updated
assert param.data.float8_data[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 20)

def test_to_device(self):
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)


if __name__ == "__main__":
run_tests()
41 changes: 41 additions & 0 deletions test/dtypes/test_fbgemm_int4.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ def setUp(self):
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
transpose_input=True,
)
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

def test_linear(self):
dtype = torch.bfloat16
Expand Down Expand Up @@ -111,6 +119,39 @@ def test_slice_and_copy_(self):
# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)

def test_to_device(self):
for device in self.GPU_DEVICES:
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device=device)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, self.config)
linear.to(device)


if __name__ == "__main__":
run_tests()
6 changes: 4 additions & 2 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .fbgemm_fp8_tensor import to_fbgemm_fp8
from .fbgemm_int4_tensor import to_fbgemm_int4
from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
from .floatx import (
CutlassSemiSparseLayout,
Float8Layout,
Expand Down Expand Up @@ -64,5 +64,7 @@
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"to_fbgemm_int4",
"FbgemmInt4Tensor",
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
]
50 changes: 45 additions & 5 deletions torchao/dtypes/fbgemm_fp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

__all__ = [
"to_fbgemm_fp8",
"FbgemmFp8Tensor",
]

aten = torch.ops.aten
Expand Down Expand Up @@ -74,11 +75,22 @@ def __repr__(self):
def _quantization_type(self):
return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs.pop("device")
return self.__class__(
self.float8_data.to(device),
self.scale.to(device),
self.activation_scale_ub.to(device),
self.dtype,
)

@classmethod
def from_float(
cls,
w: torch.Tensor,
activation_scale_ub: Optional[float] = None,
transpose_input: bool = False,
):
if activation_scale_ub is None:
activation_scale_ub = 1200.0
Expand All @@ -88,6 +100,12 @@ def from_float(
dtype=torch.float,
device=w.device,
)
if transpose_input:
if w.ndim == 3:
w = w.transpose(-1, -2)
else:
w = w.t()

wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
# wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
dtype = w.dtype
Expand All @@ -110,11 +128,6 @@ def _(func, types, args, kwargs):
args[1],
args[2] if len(args) > 2 else None,
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

Expand All @@ -141,6 +154,33 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
args[1],
)
orig_act_size = input_tensor.size()
# not used
num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ot use num_tokens feels weird, maybe make an issue on fbgemm? or update the op to not need

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I checked with @jiawenliu64 and this arg is indeed only used in internal use cases, he was recommending to use the triton op, although I found the triton op is a bit slower, maybe it requires some tuning. I'll double check

input_tensor, num_tokens, weight_tensor.activation_scale_ub
)

a_data = xq
b_data = weight_tensor.float8_data
orig_out_features = b_data.shape[-2]

res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
a_data,
b_data,
x_scale,
weight_tensor.scale,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
Expand Down
43 changes: 38 additions & 5 deletions torchao/dtypes/fbgemm_int4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

__all__ = [
"to_fbgemm_int4",
"FbgemmInt4Tensor",
]

aten = torch.ops.aten
Expand Down Expand Up @@ -77,18 +78,36 @@ def __repr__(self):
def _quantization_type(self):
return f"shape={self.shape}, group_size={self.group_size}, device={self.device}"

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
device = kwargs.pop("device")
return self.__class__(
self.packed_weight.to(device),
self.scale.to(device),
self.zero_point.to(device),
self.group_size,
self.shape,
)

@classmethod
def from_float(
cls,
w: torch.Tensor,
block_size: List[int],
transpose_input: bool = False,
):
assert len(block_size) == w.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
)
if int4_row_quantize_zp is None:
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

if transpose_input:
if w.ndim == 3:
w = w.transpose(-1, -2)
else:
w = w.t()

group_size = block_size[-1]
original_shape = w.shape

Expand Down Expand Up @@ -126,11 +145,6 @@ def _(func, types, args, kwargs):
args[1],
args[2] if len(args) > 2 else None,
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

Expand All @@ -146,6 +160,25 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
args[1],
)
orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(
input_tensor,
weight_tensor.packed_weight.contiguous(),
weight_tensor.scale,
weight_tensor.zero_point,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
Expand Down
5 changes: 5 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,7 @@ class FbgemmConfig(AOBaseConfig):
output_dtype: torch.dtype
block_size: Optional[List[int]] = None
activation_scale_ub: Optional[float] = None
transpose_input: bool = False


@register_quantize_module_handler(FbgemmConfig)
Expand Down Expand Up @@ -2018,9 +2019,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
weight = to_fbgemm_int4(
module.weight,
config.block_size,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
elif (
(config.input_dtype == e4m3_dtype)
and (config.weight_dtype == e4m3_dtype)
Expand All @@ -2029,9 +2032,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
weight = to_fbgemm_fp8(
module.weight,
config.activation_scale_ub,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
else:
raise NotImplementedError(
f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}"
Expand Down
Loading