From 60d5cfcd277968a69f8088d2f1c25f331d5eaf97 Mon Sep 17 00:00:00 2001
From: Tobias van der Werff <33268192+tobiasvanderwerff@users.noreply.github.com>
Date: Tue, 24 Sep 2024 19:41:39 +0200
Subject: [PATCH] Fix failing FP6 benchmark

---
 benchmarks/benchmark_fp6.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py
index e9f9d21398..9b8dcf3387 100644
--- a/benchmarks/benchmark_fp6.py
+++ b/benchmarks/benchmark_fp6.py
@@ -1,7 +1,7 @@
 import torch
 import pandas as pd
 import torch.nn.functional as F
-from torchao.dtypes import to_affine_quantized_floatx
+from torchao.dtypes import to_affine_quantized_fpx
 from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
@@ -9,7 +9,7 @@ def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
+    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)
     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
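
For reference, a minimal sketch of how the renamed entry point is used after this patch, assuming a CUDA-capable GPU and a torchao build that exports to_affine_quantized_fpx; only calls that appear in the patch are used, and the tensor shape below is illustrative rather than taken from the benchmark:

    import torch
    from torchao.dtypes import to_affine_quantized_fpx
    from torchao.dtypes.floatx import FloatxTensorCoreLayoutType

    # Quantize an fp16 weight to FP6 (3 exponent bits, 2 mantissa bits),
    # matching the FloatxTensorCoreLayoutType(3, 2) call in the benchmark.
    float_data = torch.randn(1024, 4096, dtype=torch.half, device="cuda")
    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))

    # Dequantize back to fp16, as the benchmark does to build its fp16 baseline.
    fp16_weight = fp6_weight.dequantize(torch.half)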