diff --git a/test/dtypes/test_fbgemm_fp8.py b/test/dtypes/test_fbgemm_fp8.py
index 56cf5ea081..1e681d00f9 100644
--- a/test/dtypes/test_fbgemm_fp8.py
+++ b/test/dtypes/test_fbgemm_fp8.py
@@ -34,6 +34,13 @@ def setUp(self):
             weight_dtype=e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
+        self.bmm_config = FbgemmConfig(
+            input_dtype=e4m3_dtype,
+            weight_dtype=e4m3_dtype,
+            output_dtype=torch.bfloat16,
+            transpose_input=True,
+        )
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     def test_linear(self):
         dtype = torch.bfloat16
@@ -106,6 +113,39 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.float8_data[0][0] != orig_value

+    def test_bmm(self):
+        class M(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
+        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
+        m = M(weight).eval()
+        original = m(input)
+        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantized = m(input)
+        self.assertTrue(compute_error(original, quantized) > 20)
+
+    def test_to_device(self):
+        for device in self.GPU_DEVICES:
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+

 if __name__ == "__main__":
     run_tests()
diff --git a/test/dtypes/test_fbgemm_int4.py b/test/dtypes/test_fbgemm_int4.py
index 25b71f0244..cba9d81ae0 100644
--- a/test/dtypes/test_fbgemm_int4.py
+++ b/test/dtypes/test_fbgemm_int4.py
@@ -34,6 +34,14 @@ def setUp(self):
             output_dtype=torch.bfloat16,
             block_size=[1, 128],
         )
+        self.bmm_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, 1, 128],
+            transpose_input=True,
+        )
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

     def test_linear(self):
         dtype = torch.bfloat16
@@ -111,6 +119,39 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.packed_weight[0][0] != orig_value

+    def test_bmm(self):
+        class M(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
+        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
+        m = M(weight).eval()
+        original = m(input)
+        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantized = m(input)
+        self.assertTrue(compute_error(original, quantized) > 18)
+
+    def test_to_device(self):
+        for device in self.GPU_DEVICES:
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device=device)
+
+            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            quantize_(linear, self.config)
+            linear.to(device)
+

 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 692d56ad31..581c3e4ecb 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -8,8 +8,8 @@
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .fbgemm_fp8_tensor import to_fbgemm_fp8
-from .fbgemm_int4_tensor import to_fbgemm_int4
+from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8
+from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4
 from .floatx import (
     CutlassSemiSparseLayout,
     Float8Layout,
@@ -64,5 +64,7 @@
     "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "Int4XPULayout",
     "to_fbgemm_int4",
+    "FbgemmInt4Tensor",
     "to_fbgemm_fp8",
+    "FbgemmFp8Tensor",
 ]
diff --git a/torchao/dtypes/fbgemm_fp8_tensor.py b/torchao/dtypes/fbgemm_fp8_tensor.py
index df7ce69de7..b6c1d72acc 100644
--- a/torchao/dtypes/fbgemm_fp8_tensor.py
+++ b/torchao/dtypes/fbgemm_fp8_tensor.py
@@ -18,6 +18,7 @@

 __all__ = [
     "to_fbgemm_fp8",
+    "FbgemmFp8Tensor",
 ]

 aten = torch.ops.aten
@@ -74,11 +75,22 @@ def __repr__(self):
     def _quantization_type(self):
         return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        return self.__class__(
+            self.float8_data.to(device),
+            self.scale.to(device),
+            self.activation_scale_ub.to(device),
+            self.dtype,
+        )
+
     @classmethod
     def from_float(
         cls,
         w: torch.Tensor,
         activation_scale_ub: Optional[float] = None,
+        transpose_input: bool = False,
     ):
         if activation_scale_ub is None:
             activation_scale_ub = 1200.0
@@ -88,6 +100,12 @@ def from_float(
             dtype=torch.float,
             device=w.device,
         )
+        if transpose_input:
+            if w.ndim == 3:
+                w = w.transpose(-1, -2)
+            else:
+                w = w.t()
+
         wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
         # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
@@ -110,11 +128,6 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    if not input_tensor.is_floating_point():
-        raise NotImplementedError(
-            f"{func} is not implemented for non floating point input"
-        )
-
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]

@@ -141,6 +154,33 @@ def _(func, types, args, kwargs):
     return res


+@implements(torch.bmm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    orig_act_size = input_tensor.size()
+    # not used
+    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
+    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+        input_tensor, num_tokens, weight_tensor.activation_scale_ub
+    )
+
+    a_data = xq
+    b_data = weight_tensor.float8_data
+    orig_out_features = b_data.shape[-2]
+
+    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
+        a_data,
+        b_data,
+        x_scale,
+        weight_tensor.scale,
+    )
+    res = res.reshape(*orig_act_size[:-1], orig_out_features)
+    return res
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/dtypes/fbgemm_int4_tensor.py
index ab108fea06..c398442168 100644
--- a/torchao/dtypes/fbgemm_int4_tensor.py
+++ b/torchao/dtypes/fbgemm_int4_tensor.py
@@ -19,6 +19,7 @@

 __all__ = [
     "to_fbgemm_int4",
+    "FbgemmInt4Tensor",
 ]

 aten = torch.ops.aten
@@ -77,11 +78,23 @@ def __repr__(self):
     def _quantization_type(self):
         return f"shape={self.shape}, group_size={self.group_size}, device={self.device}"

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale.to(device),
+            self.zero_point.to(device),
+            self.group_size,
+            self.shape,
+        )
+
     @classmethod
     def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        transpose_input: bool = False,
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
@@ -89,6 +102,12 @@ def from_float(
         if int4_row_quantize_zp is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")

+        if transpose_input:
+            if w.ndim == 3:
+                w = w.transpose(-1, -2)
+            else:
+                w = w.t()
+
         group_size = block_size[-1]
         original_shape = w.shape

@@ -126,11 +145,6 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    if not input_tensor.is_floating_point():
-        raise NotImplementedError(
-            f"{func} is not implemented for non floating point input"
-        )
-
     orig_act_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]

@@ -146,6 +160,25 @@ def _(func, types, args, kwargs):
     return res


+@implements(torch.bmm)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    orig_act_size = input_tensor.size()
+    orig_out_features = weight_tensor.shape[-2]
+
+    res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(
+        input_tensor,
+        weight_tensor.packed_weight.contiguous(),
+        weight_tensor.scale,
+        weight_tensor.zero_point,
+    )
+    res = res.reshape(*orig_act_size[:-1], orig_out_features)
+    return res
+
+
 @implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index be25b144a6..d8af23414b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1991,6 +1991,7 @@ class FbgemmConfig(AOBaseConfig):
     output_dtype: torch.dtype
     block_size: Optional[List[int]] = None
     activation_scale_ub: Optional[float] = None
+    transpose_input: bool = False


 @register_quantize_module_handler(FbgemmConfig)
@@ -2018,9 +2019,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_int4(
             module.weight,
             config.block_size,
+            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
@@ -2029,9 +2032,11 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         weight = to_fbgemm_fp8(
             module.weight,
             config.activation_scale_ub,
+            config.transpose_input,
         )
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
+        return module
     else:
         raise NotImplementedError(
             f"{config} is not supported. supported input, weight, output kernel dtypes are: {_SUPPORTED_DTYPES}"