Summary
This RFC proposes a code organization strategy to support multiple hardware backends (e.g., Ascend NPU) in Liger-Kernel while maintaining clean separation from the default CUDA implementation.
Motivation
As referenced in RFC #954, there is growing demand for Ascend NPU support in Liger-Kernel. Different hardware backends may require:
- Device-specific kernel implementations
- Different performance tuning parameters
- Hardware-specific optimizations
A well-designed code organization ensures vendor-specific adaptations remain isolated from the core codebase.
Design Goals
- Clean Separation: Vendor code is isolated from default implementations
- Incremental Override: Vendors only implement kernels that need adaptation
- Zero User Code Changes: Replacement is transparent to users
- Easy Extension: Adding new vendors requires minimal boilerplate
Architecture
Directory Structure
src/liger_kernel/ops/
├── __init__.py              # Exports ops + runs replacement logic
├── geglu.py                 # Default implementation
├── rms_norm.py
├── ...
└── backends/
    ├── __init__.py          # Imports vendor packages to trigger registration
    ├── registry.py          # VendorInfo, register_vendor(), VENDOR_REGISTRY
    └── _ascend/             # Ascend vendor (supports NPU)
        ├── __init__.py      # Calls register_vendor()
        └── ops/
            ├── __init__.py  # Exports vendor implementations
            └── geglu.py     # NPU-specific GEGLU
Core Components
1. VendorInfo
from dataclasses import dataclass

@dataclass
class VendorInfo:
    vendor: str       # e.g., "ascend", "intel"
    device: str       # e.g., "npu", "xpu"
    module_path: str  # e.g., "liger_kernel.ops.backends._ascend.ops"

2. Vendor Registry
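from typing import Optional

VENDOR_REGISTRY: dict[str, VendorInfo] = {}

def register_vendor(vendor_info: VendorInfo) -> None:
    # Keyed by device string, so each device maps to at most one vendor.
    VENDOR_REGISTRY[vendor_info.device] = vendor_info

def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    # Returns None when no vendor is registered for the device.
    return VENDOR_REGISTRY.get(device)

3. Vendor Self-Registration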
Each vendor registers itself in its __init__.py:
from liger_kernel.ops.backends.registry import VendorInfo, register_vendor

register_vendor(
    VendorInfo(
        vendor="ascend",
        device="npu",
        module_path="liger_kernel.ops.backends._ascend.ops",
    )
)

4. Replacement Logic
When liger_kernel.ops is imported, the replacement logic executes automatically:
- Detect current device via infer_device()
- If device is CUDA, use default implementations (no replacement)
- Look up VendorInfo from VENDOR_REGISTRY by device type
- Dynamically import vendor's ops module
- Replace/add symbols in liger_kernel.ops namespace via globals()
Vendors control exports via __all__. If __all__ is not defined, all public symbols are auto-discovered.
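The steps above can be sketched as follows. This is a minimal illustration, not the RFC's actual implementation: `apply_vendor_overrides` is a hypothetical helper name, the device string is passed in rather than obtained from `infer_device()`, a plain dict stands in for the `globals()` of `liger_kernel/ops/__init__.py`, and "auto-discovered" is assumed to mean every name without a leading underscore. A fake module registered in `sys.modules` replaces the real vendor package so the sketch runs without NPU hardware.

```python
import importlib
import sys
import types
from dataclasses import dataclass
from typing import Optional


@dataclass
class VendorInfo:
    vendor: str
    device: str
    module_path: str


def apply_vendor_overrides(namespace: dict, device: str, registry: dict) -> list:
    """Rebind names in `namespace` to the vendor's versions; return the names replaced."""
    if device == "cuda":
        return []  # default implementations stay in place
    info: Optional[VendorInfo] = registry.get(device)
    if info is None:
        return []  # no vendor registered for this device
    vendor_ops = importlib.import_module(info.module_path)
    names = getattr(vendor_ops, "__all__", None)
    if names is None:
        # Assumed fallback: treat every name without a leading underscore as public.
        names = [n for n in dir(vendor_ops) if not n.startswith("_")]
    for name in names:
        namespace[name] = getattr(vendor_ops, name)
    return sorted(names)


# Stand-in for a vendor ops module (hypothetical name, not part of Liger-Kernel).
fake_ops = types.ModuleType("fake_vendor_ops")
fake_ops.geglu_forward = lambda a, b: "npu"
fake_ops.__all__ = ["geglu_forward"]
sys.modules["fake_vendor_ops"] = fake_ops

registry = {"npu": VendorInfo("ascend", "npu", "fake_vendor_ops")}
namespace = {
    "geglu_forward": lambda a, b: "default",
    "rms_norm_forward": lambda x: "default",
}

replaced = apply_vendor_overrides(namespace, "npu", registry)
```

Only the names the vendor exports are rebound; anything the vendor does not provide (here `rms_norm_forward`) keeps its default, which is the "Incremental Override" goal.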
Key Design Decisions
1. Vendor-based Directory Naming
Directories are named by vendor (e.g., _ascend, _intel).
2. Self-Registration Pattern
Each vendor registers itself in its __init__.py. This decouples vendor configuration from the central registry, allowing vendors to manage their own metadata independently.
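A small round trip through the registry may make the pattern concrete. The definitions below restate the RFC's `VendorInfo` and registry functions so the snippet runs on its own; the second vendor (`intel` / `xpu` with a `_intel` module path) is hypothetical and only borrows the example names from the `VendorInfo` comments.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VendorInfo:
    vendor: str
    device: str
    module_path: str


VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def register_vendor(vendor_info: VendorInfo) -> None:
    VENDOR_REGISTRY[vendor_info.device] = vendor_info


def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    return VENDOR_REGISTRY.get(device)


# What backends/_ascend/__init__.py does at import time.
register_vendor(VendorInfo("ascend", "npu", "liger_kernel.ops.backends._ascend.ops"))
# Hypothetical second vendor, for illustration only.
register_vendor(VendorInfo("intel", "xpu", "liger_kernel.ops.backends._intel.ops"))

npu_vendor = get_vendor_for_device("npu")    # VendorInfo for Ascend
cuda_vendor = get_vendor_for_device("cuda")  # None: CUDA has no vendor entry
```

Because the registry is keyed by the device string, a second registration for the same device would replace the first one.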
3. Module-level Replacement
Replacement happens at the liger_kernel.ops package level:
- from liger_kernel.ops import LigerGELUMulFunction: automatically replaced for vendor devices
- from liger_kernel.ops.geglu import LigerGELUMulFunction: always uses default implementation
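The difference between the two import paths follows from how Python binds names: rebinding an attribute on the package does not touch the submodule that originally defined it. The sketch below uses stand-in module objects and placeholder classes (all names are illustrative, not the real Liger-Kernel modules) to show that behavior.

```python
import types

# Stand-ins for liger_kernel.ops (package) and liger_kernel.ops.geglu (submodule).
ops_pkg = types.ModuleType("ops_pkg_demo")
geglu_mod = types.ModuleType("ops_pkg_demo.geglu")


class DefaultGELUMul:
    """Placeholder for the default implementation."""


class VendorGELUMul:
    """Placeholder for a vendor implementation."""


geglu_mod.LigerGELUMulFunction = DefaultGELUMul
ops_pkg.geglu = geglu_mod
ops_pkg.LigerGELUMulFunction = geglu_mod.LigerGELUMulFunction  # package re-export

# The replacement step rebinds only the package-level name.
ops_pkg.LigerGELUMulFunction = VendorGELUMul

package_level = ops_pkg.LigerGELUMulFunction          # vendor class
submodule_level = ops_pkg.geglu.LigerGELUMulFunction  # still the default class
```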
4. Explicit Export Control
Vendors use __all__ to explicitly declare exports. If __all__ is not defined, all public symbols are auto-discovered.
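One reason to prefer an explicit `__all__` is that auto-discovery can pick up names a vendor module merely imported. The sketch below assumes that "public" means "no leading underscore" (the RFC does not spell out the rule), and `exported_names` is a hypothetical helper, not a Liger-Kernel function.

```python
import types


def exported_names(module):
    """Return the module's declared exports, or fall back to underscore-free names."""
    declared = getattr(module, "__all__", None)
    if declared is not None:
        return list(declared)
    return [n for n in vars(module) if not n.startswith("_")]


# Build a throwaway vendor module that imports a helper and defines one op.
vendor_ops = types.ModuleType("vendor_ops_demo")
exec(
    "import math\n"
    "def geglu_forward(a, b):\n"
    "    return a * b\n",
    vendor_ops.__dict__,
)

without_all = sorted(exported_names(vendor_ops))  # the imported `math` leaks in

vendor_ops.__all__ = ["geglu_forward"]
with_all = exported_names(vendor_ops)             # only the declared op
```

Under this assumed rule, a vendor module without `__all__` would inject `math` into the `liger_kernel.ops` namespace alongside the op itself.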
Usage
For Users
No changes required. Existing code works transparently:
from liger_kernel.ops import LigerGELUMulFunction
# Automatically uses vendor implementation on corresponding devices

For Vendor Contributors
- Create backends/_<vendor>/ directory structure
- Register vendor in _<vendor>/__init__.py
- Add import in backends/__init__.py
- Implement ops in _<vendor>/ops/
- Export in _<vendor>/ops/__init__.py
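As a rough illustration of the file layout these steps produce, the following script scaffolds the boilerplate for a new vendor in a temporary directory. `scaffold_vendor` is a hypothetical helper, not something the RFC provides, and the `from . import _<vendor>` line written into `backends/__init__.py` is an assumed form of the import in step 3. The op implementations themselves (step 4) still have to be written by hand.

```python
import tempfile
from pathlib import Path

# Mirrors the self-registration snippet shown for Ascend.
INIT_TEMPLATE = '''from liger_kernel.ops.backends.registry import VendorInfo, register_vendor

register_vendor(
    VendorInfo(
        vendor="{vendor}",
        device="{device}",
        module_path="liger_kernel.ops.backends._{vendor}.ops",
    )
)
'''


def scaffold_vendor(backends_dir: Path, vendor: str, device: str) -> Path:
    """Create the directory skeleton and boilerplate files for one vendor."""
    vendor_dir = backends_dir / f"_{vendor}"
    ops_dir = vendor_dir / "ops"
    ops_dir.mkdir(parents=True, exist_ok=True)  # steps 1 and 4: directories

    # Step 2: the vendor registers itself on import.
    (vendor_dir / "__init__.py").write_text(
        INIT_TEMPLATE.format(vendor=vendor, device=device)
    )
    # Step 5: placeholder export list, to be filled in as ops are implemented.
    (ops_dir / "__init__.py").write_text("__all__ = []\n")

    # Step 3: make backends/__init__.py import the vendor package (assumed form).
    backends_init = backends_dir / "__init__.py"
    import_line = f"from . import _{vendor}  # noqa: F401\n"
    existing = backends_init.read_text() if backends_init.exists() else ""
    if import_line not in existing:
        backends_init.write_text(existing + import_line)
    return vendor_dir


with tempfile.TemporaryDirectory() as tmp:
    scaffold_vendor(Path(tmp), "intel", "xpu")
    created_files = sorted(
        p.relative_to(tmp).as_posix() for p in Path(tmp).rglob("*.py")
    )
    vendor_init_text = (Path(tmp) / "_intel" / "__init__.py").read_text()
```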
Migration
Codebase Changes
Imports in transformers/ are updated from submodule level to package level:
# Before
from liger_kernel.ops.geglu import LigerGELUMulFunction

# After
from liger_kernel.ops import LigerGELUMulFunction

Backward Compatibility
- Direct submodule imports remain functional and always return default implementations
- No API changes for end users
Example: Ascend NPU GEGLU
# backends/_ascend/ops/geglu.py
import torch

class LigerGELUMulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # NPU-specific implementation with BLOCK_SIZE_SUB
        ...

    @staticmethod
    def backward(ctx, dc):
        ...

def geglu_forward(a, b):
    # NPU kernel with UB overflow prevention
    ...

# backends/_ascend/ops/__init__.py
from .geglu import (
    LigerGELUMulFunction,
    geglu_forward,
)

__all__ = ["LigerGELUMulFunction", "geglu_forward"]