Skip to content

Commit a703422

Browse files
committed
up
1 parent 77c2491 commit a703422

File tree

3 files changed: +5 additions, -9 deletions

test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@
2626
from torchao.quantization.granularity import PerGroup
2727
from torchao.quantization.quantize_.common import PackingFormat
2828
from torchao.quantization.utils import compute_error
29-
from torchao.utils import (
30-
TORCH_VERSION_AT_LEAST_2_8,
31-
)
29+
from torchao.utils import torch_version_at_least
3230

3331

34-
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
32+
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
3533
class TestIntxUnpackedTensor(TestCase):
3634
def setUp(self):
3735
self.config = IntxWeightOnlyConfig(

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
_DTYPE_TO_QVALUE_BOUNDS,
119119
MappingType,
120120
ZeroPointDomain,
121+
quantize_affine,
121122
)
122123
from .subclass import (
123124
QuantizedLinearWeightBase,
@@ -797,7 +798,6 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
797798
act_mapping_type = config.act_mapping_type
798799
layout = config.layout
799800
packing_format = config.packing_format
800-
compute_target = config.compute_target
801801

802802
assert weight.dim() == 2, (
803803
f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"

torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020
from torchao.quantization.utils import _get_per_token_block_size
2121
from torchao.utils import (
22-
TORCH_VERSION_AT_LEAST_2_5,
2322
TorchAOBaseTensor,
2423
fill_defaults,
2524
)
@@ -300,6 +299,5 @@ def _(func, types, args, kwargs):
300299

301300
IntxUnpackedTensor.__module__ = "torchao.quantization"
302301

303-
if TORCH_VERSION_AT_LEAST_2_5:
304-
# Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True`
305-
torch.serialization.add_safe_globals([IntxUnpackedTensor])
302+
# Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True`
303+
torch.serialization.add_safe_globals([IntxUnpackedTensor])

0 commit comments

Comments (0)