Migrate xnnpack/vulkan/boltnn pt2e from torch.ao to torchao #11363

Merged (1 commit) on Jun 9, 2025
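
The change itself is mechanical: every pt2e quantization symbol that used to be imported from `torch.ao.quantization.*` now comes from `torchao.quantization.pt2e.*`. A minimal sketch of the before/after pattern, using only module paths that appear in the diffs below (surrounding code omitted):

```python
# Before this PR: pt2e observers and quantizer types came from torch.ao.
# from torch.ao.quantization.observer import PerChannelMinMaxObserver
# from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

# After this PR: the same symbols are imported from torchao.
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
    QuantizationConfig,
    QuantizationSpec,
    Quantizer,
)
```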
8 changes: 1 addition & 7 deletions .lintrunner.toml
@@ -386,15 +386,9 @@ exclude_patterns = [
"third-party/**",
# TODO: remove exceptions as we migrate
# backends
"backends/vulkan/quantizer/**",
"backends/vulkan/test/**",
"backends/xnnpack/quantizer/**",
"backends/xnnpack/test/**",
"exir/tests/test_passes.py",
"extension/llm/export/builder.py",
"extension/llm/export/quantizer_lib.py",
"exir/tests/test_memory_planning.py",
"exir/backend/test/demos/test_xnnpack_qnnpack.py",
"backends/xnnpack/test/test_xnnpack_utils.py",
]

command = [
9 changes: 6 additions & 3 deletions backends/vulkan/quantizer/vulkan_quantizer.py
@@ -16,11 +16,14 @@
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
propagate_annotation,
QuantizationConfig,
)
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.fx import Node
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import (
QuantizationConfig,
QuantizationSpec,
Quantizer,
)


__all__ = [
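
For context, a small sketch of how the torchao types imported above compose. The numeric values are illustrative only and are not necessarily the defaults the Vulkan quantizer uses:

```python
import torch
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Illustrative 8-bit symmetric per-channel weight spec built from the
# torchao classes that replace the torch.ao ones in this file.
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)
```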
4 changes: 2 additions & 2 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -23,12 +23,12 @@
EdgeProgramManager,
ExecutorchProgramManager,
)

from torch.ao.quantization.quantizer import Quantizer
from torch.export import Dim, export, export_for_training, ExportedProgram

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from torchao.quantization.pt2e.quantizer import Quantizer

ctypes.CDLL("libvulkan.so.1")


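
The tests exercise the standard pt2e flow, now driven entirely by the torchao entry points imported above. A minimal sketch of that flow, assuming any torchao-style `Quantizer` instance (the helper name below is hypothetical):

```python
import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer


def quantize_for_test(model: torch.nn.Module, sample_inputs: tuple, quantizer: Quantizer):
    # Export, insert observers, calibrate once, then convert to a
    # quantized reference graph -- the sequence the delegate tests drive.
    exported = export_for_training(model, sample_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*sample_inputs)
    return convert_pt2e(prepared)
```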
2 changes: 1 addition & 1 deletion backends/vulkan/test/test_vulkan_passes.py
@@ -16,9 +16,9 @@
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from torch.ao.quantization.quantizer import Quantizer

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer

###################
## Common Models ##
21 changes: 6 additions & 15 deletions backends/xnnpack/partition/config/quant_affine_configs.py
@@ -33,33 +33,24 @@ class QuantizeAffineConfig(QDQAffineConfigs):
target_name = "quantize_affine.default"

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa
import torchao.quantization.quant_primitives # noqa

return torch.ops.torchao.quantize_affine.default
except:
return None
return torch.ops.torchao.quantize_affine.default


class DeQuantizeAffineConfig(QDQAffineConfigs):
target_name = "dequantize_affine.default"

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa
import torchao.quantization.quant_primitives # noqa

return torch.ops.torchao.dequantize_affine.default
except:
return None
return torch.ops.torchao.dequantize_affine.default


class ChooseQParamsAffineConfig(QDQAffineConfigs):
target_name = "choose_qparams_affine.default"

def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa
import torchao.quantization.quant_primitives # noqa

return torch.ops.torchao.choose_qparams_affine.default
except:
return None
return torch.ops.torchao.choose_qparams_affine.default
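
The try/except wrappers are dropped because torchao is now an unconditional dependency of these configs; the import is kept purely for its side effect of registering the `torchao::*` custom ops. A small sketch of that assumption (the assertion is illustrative, not part of the config code):

```python
import torch
import torchao.quantization.quant_primitives  # noqa: F401  # registers torch.ops.torchao.*

# With torchao installed, the affine quant/dequant ops resolve without any
# fallback, so get_original_aten() can return them directly.
assert hasattr(torch.ops.torchao, "quantize_affine")
quantize_affine = torch.ops.torchao.quantize_affine.default
dequantize_affine = torch.ops.torchao.dequantize_affine.default
```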
25 changes: 13 additions & 12 deletions backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -12,30 +12,31 @@
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
)
from torch.ao.quantization.fake_quantize import (
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
from torchao.quantization.pt2e.quantizer import (
get_module_name_filter,
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
QuantizationSpec,
Quantizer,
)


if TYPE_CHECKING:
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
from torchao.quantization.pt2e import ObserverOrFakeQuantizeConstructor


__all__ = [
@@ -140,7 +141,7 @@ def get_symmetric_quantization_config(
weight_qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
@@ -228,7 +229,7 @@ def _get_not_module_type_or_name_filter(
tp_list: list[Callable], module_name_list: list[str]
) -> Callable[[Node], bool]:
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]

def not_module_type_or_name_filter(n: Node) -> bool:
return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -421,7 +422,7 @@ def _annotate_for_quantization_config(
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_patterns(
model, config, _get_module_name_filter(module_name)
model, config, get_module_name_filter(module_name)
)

tp_list = list(self.module_type_config.keys())
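
Typical use of this quantizer is unchanged by the migration; only the types backing the config move to torchao. A short sketch, assuming the quantizer's existing `set_global` API:

```python
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# The returned config is now built from torchao's QuantizationConfig /
# QuantizationSpec rather than the torch.ao equivalents.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
```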
137 changes: 27 additions & 110 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -1,39 +1,43 @@
# mypy: allow-untyped-defs
import itertools
import typing
from dataclasses import dataclass
from typing import Callable, NamedTuple, Optional
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
_get_aten_graph_module_for_pattern,
_is_conv_node,
_is_conv_transpose_node,
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e import WrapperModule
from torchao.quantization.pt2e.graph_utils import get_source_partitions
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
from torchao.quantization.pt2e.utils import (
_get_aten_graph_module_for_pattern,
_is_conv_node,
_is_conv_transpose_node,
get_new_attr_name_with_prefix,
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

__all__ = [
"OperatorConfig",
"OperatorPatternType",
"QuantizationConfig",
"QuantizationSpec",
"get_input_act_qspec",
"get_output_act_qspec",
"get_weight_qspec",
@@ -43,23 +47,6 @@
]


# In the absence of better name, just winging it with QuantizationConfig
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
input_activation: Optional[QuantizationSpec]
output_activation: Optional[QuantizationSpec]
weight: Optional[QuantizationSpec]
bias: Optional[QuantizationSpec]
# TODO: remove, since we can use observer_or_fake_quant_ctr to express this
is_qat: bool = False


# Use Annotated because list[Callable].__module__ is read-only.
OperatorPatternType = typing.Annotated[list[Callable], None]
OperatorPatternType.__module__ = (
"executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils"
)

AnnotatorType = Callable[
[
torch.fx.GraphModule,
@@ -78,19 +65,6 @@ def decorator(annotator: AnnotatorType) -> None:
return decorator


class OperatorConfig(NamedTuple):
# fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
# Basically we are mapping a quantization config to some list of patterns.
# a pattern is defined as a list of nn module, function or builtin function names
# e.g. [nn.Conv2d, torch.relu, torch.add]
# We have not resolved whether fusion can be considered internal details of the
# quantizer hence it does not need communication to user.
# Note this pattern is not really informative since it does not really
# tell us the graph structure resulting from the list of ops.
config: QuantizationConfig
operators: list[OperatorPatternType]


def is_relu_node(node: Node) -> bool:
"""
Check if a given node is a relu node
Expand Down Expand Up @@ -124,63 +98,6 @@ def _mark_nodes_as_annotated(nodes: list[Node]):
node.meta["quantization_annotation"]._annotated = True


def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
if quantization_config.input_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.input_activation
assert quantization_spec.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
return quantization_spec


def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
if quantization_config.output_activation is None:
return None
quantization_spec: QuantizationSpec = quantization_config.output_activation
assert quantization_spec.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
return quantization_spec


def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert quantization_config is not None
if quantization_config.weight is None:
return None
quantization_spec: QuantizationSpec = quantization_config.weight
if quantization_spec.qscheme not in [
torch.per_tensor_symmetric,
torch.per_channel_symmetric,
None,
]:
raise ValueError(
f"Unsupported quantization_spec {quantization_spec} for weight"
)
return quantization_spec


def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert quantization_config is not None
if quantization_config.bias is None:
return None
quantization_spec: QuantizationSpec = quantization_config.bias
assert (
quantization_spec.dtype == torch.float
), "Only float dtype for bias is supported for bias right now"
return quantization_spec


@register_annotator("linear")
def _annotate_linear(
gm: torch.fx.GraphModule,
@@ -204,25 +121,25 @@ def _annotate_linear(
bias_node = node.args[2]

if _is_annotated([node]) is False: # type: ignore[list-item]
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
input_act_qspec,
)
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
weight_qspec,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias_node,
bias_qspec,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, output_act_qspec)
annotate_output_qspec(node, output_act_qspec)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
annotated_partitions.append(nodes_to_mark_annotated)

@@ -572,7 +489,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
"output": output,
}

return _WrapperModule(_conv_bn)
return WrapperModule(_conv_bn)

# Needed for matching, otherwise the matches gets filtered out due to unused
# nodes returned by batch norm
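
With the local `QuantizationConfig`, `OperatorConfig`, and `get_*_qspec` definitions deleted, annotators in this file now consume the torchao-provided helpers directly. A condensed, hypothetical annotator body showing the renamed helpers (`annotate_*` instead of the old underscore-prefixed torch.ao versions):

```python
import torch
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_weight_qspec,
    QuantizationConfig,
)


def annotate_linear_like(node: torch.fx.Node, config: QuantizationConfig) -> None:
    # Hypothetical helper: attach input/weight/output qspecs to a single
    # aten.linear node, the same way _annotate_linear above does.
    act_node, weight_node = node.args[0], node.args[1]
    annotate_input_qspec_map(node, act_node, get_input_act_qspec(config))
    annotate_input_qspec_map(node, weight_node, get_weight_qspec(config))
    annotate_output_qspec(node, get_output_act_qspec(config))
```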