Skip to content

Add Finegrained FP8 #11647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/nightly_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ jobs:
additional_deps: []
- backend: "optimum_quanto"
test_location: "quanto"
- backend: "finegrained_fp8"
test_location: "finegrained_fp8"
additional_deps: []
runs-on:
group: aws-g6e-xlarge-plus
Expand Down
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@
title: torchao
- local: quantization/quanto
title: quanto
- local: quantization/finegrained_fp8
title: finegrained_fp8
title: Quantization Methods
- sections:
- local: optimization/fp16
Expand Down
5 changes: 5 additions & 0 deletions docs/source/en/api/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

[[autodoc]] TorchAoConfig

## FinegrainedFP8Config

[[autodoc]] FinegrainedFP8Config

## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer

15 changes: 15 additions & 0 deletions docs/source/en/quantization/finegrained_fp8.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# FinegrainedFP8

## Overview

## Usage

1 change: 0 additions & 1 deletion docs/source/en/quantization/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ pipeline_quant_config = PipelineQuantizationConfig(
components_to_quantize=["transformer", "text_encoder_2"],
)
```

Pass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.

```py
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")

_import_structure["quantizers.quantization_config"].append("FinegrainedFP8Config")

try:
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
Expand Down Expand Up @@ -725,6 +727,8 @@
else:
from .quantizers.quantization_config import QuantoConfig

from .quantizers.quantization_config import FinegrainedFP8Config

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .finegrained_fp8 import FinegrainedFP8Quantizer
from .gguf import GGUFQuantizer
from .quantization_config import (
BitsAndBytesConfig,
FinegrainedFP8Config,
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
Expand All @@ -39,6 +41,7 @@
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"finegrained_fp8": FinegrainedFP8Quantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -47,6 +50,7 @@
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"finegrained_fp8": FinegrainedFP8Config,
}


Expand Down
1 change: 1 addition & 0 deletions src/diffusers/quantizers/finegrained_fp8/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer
205 changes: 205 additions & 0 deletions src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ...utils import get_module_from_name, is_accelerate_available, is_torch_available, logging
from ..base import DiffusersQuantizer


if is_torch_available():
import torch

logger = logging.get_logger(__name__)

if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin


class FinegrainedFP8Quantizer(DiffusersQuantizer):
"""
FP8 quantization implementation supporting both standard and MoE models.
Supports both e4m3fn formats based on platform.
"""

requires_parameters_quantization = True
requires_calibration = False
required_packages = ["accelerate"]

def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config

def validate_environment(self, *args, **kwargs):
if not is_torch_available():
raise ImportError(
"Using fp8 quantization requires torch >= 2.1.0"
"Please install the latest version of torch ( pip install --upgrade torch )"
)

if not is_accelerate_available():
raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")

if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Converting into FP8 weights from tf/flax weights is currently not supported, "
"please make sure the weights are in PyTorch format."
)

if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if (major < 8) or (major == 8 and minor < 9):
raise ValueError(
"FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
f", actual = `{major}.{minor}`"
)

device_map = kwargs.get("device_map", None)
if device_map is None:
logger.warning_once(
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
)
elif device_map is not None:
if (
not self.pre_quantized
and isinstance(device_map, dict)
and ("cpu" in device_map.values() or "disk" in device_map.values())
):
raise ValueError(
"You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
"This is not supported when the model is quantized on the fly. "
"Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
)

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
torch_dtype = torch.float32
return torch_dtype

def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
"""
Quantizes weights to FP8 format using Block-wise quantization
"""
# print("############ create quantized param ########")
from accelerate.utils import set_module_tensor_to_device

set_module_tensor_to_device(model, param_name, target_device, param_value)

module, tensor_name = get_module_from_name(model, param_name)

# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max

block_size_m, block_size_n = self.quantization_config.weight_block_size

rows, cols = param_value.shape[-2:]

if rows % block_size_m != 0 or cols % block_size_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n}) for {param_name}"
)
param_value_orig_shape = param_value.shape

param_value = param_value.reshape(
rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
).permute(0, 2, 1, 3)

# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)

# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

quantized_param = quantized_param.permute(0, 2, 1, 3)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)

# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).reciprocal()

# Load into the model
module._parameters[tensor_name] = quantized_param.to(target_device)
module._parameters["weight_scale_inv"] = scale.to(target_device)

def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
from .utils import FP8Linear

module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, FP8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
return False
else:
if tensor_name == "weight_scale_inv":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False

def _process_model_before_weight_loading(
self,
model: "ModelMixin",
keep_in_fp32_modules: Optional[List[str]] = None,
**kwargs,
):
from .utils import replace_with_fp8_linear

if self.quantization_config.modules_to_not_convert is not None:
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

model = replace_with_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
)

model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
return model

def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from .utils import FP8Linear

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FP8Linear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]

def is_serializable(self, safe_serialization=None):
return True

@property
def is_trainable(self) -> bool:
return False

def get_cuda_warm_up_factor(self):
# Pre-processing is done cleanly, so we can allocate everything here
return 2
Loading
Loading