Skip to content

Add support for bmm for fbgemm config #2337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/dtypes/test_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def setUp(self):
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
)
self.bmm_config = FbgemmConfig(
input_dtype=e4m3_dtype,
weight_dtype=e4m3_dtype,
output_dtype=torch.bfloat16,
transpose_input=True,
)

def test_linear(self):
dtype = torch.bfloat16
Expand Down Expand Up @@ -106,6 +112,25 @@ def test_slice_and_copy_(self):
# making sure param.data is updated
assert param.data.float8_data[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 20)


if __name__ == "__main__":
run_tests()
26 changes: 26 additions & 0 deletions test/dtypes/test_fbgemm_int4.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def setUp(self):
output_dtype=torch.bfloat16,
block_size=[1, 128],
)
self.bmm_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, 1, 128],
transpose_input=True,
)

def test_linear(self):
dtype = torch.bfloat16
Expand Down Expand Up @@ -111,6 +118,25 @@ def test_slice_and_copy_(self):
# making sure param.data is updated
assert param.data.packed_weight[0][0] != orig_value

def test_bmm(self):
class M(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()
original = m(input)
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 18)


if __name__ == "__main__":
run_tests()
40 changes: 40 additions & 0 deletions torchao/dtypes/fbgemm_fp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def from_float(
cls,
w: torch.Tensor,
activation_scale_ub: Optional[float] = None,
transpose_input: bool = False,
):
if activation_scale_ub is None:
activation_scale_ub = 1200.0
Expand All @@ -88,6 +89,12 @@ def from_float(
dtype=torch.float,
device=w.device,
)
if transpose_input:
if w.ndim == 3:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe just w.transpose(-1, -2)

w = w.transpose(1, 2)
else:
w = w.t()

wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
# wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
dtype = w.dtype
Expand Down Expand Up @@ -141,6 +148,39 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
args[1],
)
if not input_tensor.is_floating_point():
raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()

# not used
num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ot use num_tokens feels weird, maybe make an issue on fbgemm? or update the op to not need

input_tensor, num_tokens, weight_tensor.activation_scale_ub
)

a_data = xq
b_data = weight_tensor.float8_data
orig_out_features = b_data.shape[-2]

res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
a_data,
b_data,
x_scale,
weight_tensor.scale,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
Expand Down
31 changes: 31 additions & 0 deletions torchao/dtypes/fbgemm_int4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,20 @@ def from_float(
cls,
w: torch.Tensor,
block_size: List[int],
transpose_input: bool = False,
):
assert len(block_size) == w.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
)
if int4_row_quantize_zp is None:
raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

if transpose_input:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto here

if w.ndim == 3:
w = w.transpose(1, 2)
else:
w = w.t()

group_size = block_size[-1]
original_shape = w.shape

Expand Down Expand Up @@ -146,6 +153,30 @@ def _(func, types, args, kwargs):
return res


@implements(torch.bmm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = (
args[0],
args[1],
)
if not input_tensor.is_floating_point():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is this guard needed? Like is this a common situation to run into

raise NotImplementedError(
f"{func} is not implemented for non floating point input"
)

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(
input_tensor,
weight_tensor.packed_weight.contiguous(),
weight_tensor.scale,
weight_tensor.zero_point,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
return res


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
Expand Down
3 changes: 3 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,7 @@ class FbgemmConfig(AOBaseConfig):
output_dtype: torch.dtype
block_size: Optional[List[int]] = None
activation_scale_ub: Optional[float] = None
transpose_input: bool = False


@register_quantize_module_handler(FbgemmConfig)
Expand Down Expand Up @@ -2018,6 +2019,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
weight = to_fbgemm_int4(
module.weight,
config.block_size,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
Expand All @@ -2029,6 +2031,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
weight = to_fbgemm_fp8(
module.weight,
config.activation_scale_ub,
config.transpose_input,
)
module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
Expand Down
Loading