diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 8d4d58fd6b..0dd14672ff 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -3,6 +3,7 @@ CastConfig, Float8GemmConfig, Float8LinearConfig, + ScalingGranularity, ScalingType, ) from torchao.float8.float8_linear_utils import ( @@ -29,12 +30,14 @@ GemmInputRole, LinearMMConfig, Float8MMConfig, + ScalingGranularity, ] ) __all__ = [ # configuration "ScalingType", + "ScalingGranularity", "Float8GemmConfig", "Float8LinearConfig", "CastConfig",