[RFC] Code Organization for New Backend Support #965

@pillumina

Summary

This RFC proposes a code organization strategy to support multiple hardware backends (e.g., Ascend NPU) in Liger-Kernel while maintaining clean separation from the default CUDA implementation.

Motivation

As referenced in RFC #954, there is growing demand for Ascend NPU support in Liger-Kernel. Different hardware backends may require:

  • Device-specific kernel implementations
  • Different performance tuning parameters
  • Hardware-specific optimizations

A well-designed code organization ensures vendor-specific adaptations remain isolated from the core codebase.

Design Goals

  1. Clean Separation: Vendor code is isolated from default implementations
  2. Incremental Override: Vendors only implement kernels that need adaptation
  3. Zero User Code Changes: Replacement is transparent to users
  4. Easy Extension: Adding new vendors requires minimal boilerplate

Architecture

Directory Structure

src/liger_kernel/ops/
├── __init__.py              # Exports ops + runs replacement logic
├── geglu.py                 # Default implementation
├── rms_norm.py
├── ...
└── backends/
    ├── __init__.py          # Imports vendor packages to trigger registration
    ├── registry.py          # VendorInfo, register_vendor(), VENDOR_REGISTRY
    └── _ascend/             # Ascend vendor (supports NPU)
        ├── __init__.py      # Calls register_vendor()
        └── ops/
            ├── __init__.py  # Exports vendor implementations
            └── geglu.py     # NPU-specific GEGLU

Core Components

1. VendorInfo

from dataclasses import dataclass

@dataclass
class VendorInfo:
    vendor: str       # e.g., "ascend", "intel"
    device: str       # e.g., "npu", "xpu"
    module_path: str  # e.g., "liger_kernel.ops.backends._ascend.ops"

2. Vendor Registry

from typing import Optional

VENDOR_REGISTRY: dict[str, VendorInfo] = {}

def register_vendor(vendor_info: VendorInfo) -> None:
    VENDOR_REGISTRY[vendor_info.device] = vendor_info

def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    return VENDOR_REGISTRY.get(device)
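
A short usage sketch of the registry, with the definitions copied inline so the snippet runs without Liger-Kernel installed. The "xpu" lookup is only an illustration of an unregistered device. Because the registry is keyed by device string, a second registration for the same device would overwrite the first.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VendorInfo:
    vendor: str
    device: str
    module_path: str


VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def register_vendor(vendor_info: VendorInfo) -> None:
    # Keyed by device string: a later registration for the same device wins.
    VENDOR_REGISTRY[vendor_info.device] = vendor_info


def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    return VENDOR_REGISTRY.get(device)


register_vendor(VendorInfo("ascend", "npu", "liger_kernel.ops.backends._ascend.ops"))

ascend = get_vendor_for_device("npu")   # the VendorInfo registered above
missing = get_vendor_for_device("xpu")  # nothing registered for this device
```

Callers need to handle the None result, since the lookup does not raise for unknown devices.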

3. Vendor Self-Registration

Each vendor registers itself in its __init__.py:

from liger_kernel.ops.backends.registry import VendorInfo, register_vendor

register_vendor(
    VendorInfo(
        vendor="ascend",
        device="npu",
        module_path="liger_kernel.ops.backends._ascend.ops",
    )
)

4. Replacement Logic

When liger_kernel.ops is imported, the replacement logic executes automatically:

  1. Detect current device via infer_device()
  2. If device is CUDA, use default implementations (no replacement)
  3. Look up VendorInfo from VENDOR_REGISTRY by device type
  4. Dynamically import vendor's ops module
  5. Replace/add symbols in liger_kernel.ops namespace via globals()

Vendors control exports via __all__. If __all__ is not defined, all public symbols are auto-discovered.
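
The five steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the proposed implementation: apply_vendor_overrides and the demo_vendor_ops module are hypothetical names, infer_device() is a hard-coded stand-in, "public" is taken to mean names without a leading underscore, and falling back to defaults for an unregistered device is an assumption the RFC does not spell out. In the real package the namespace argument would be globals() of liger_kernel/ops/__init__.py.

```python
import importlib
import sys
import types
from dataclasses import dataclass
from typing import Optional


@dataclass
class VendorInfo:
    vendor: str
    device: str
    module_path: str


VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def infer_device() -> str:
    # Stand-in for the real device detection; hard-coded for this demo.
    return "npu"


def apply_vendor_overrides(namespace: dict, device: Optional[str] = None) -> list[str]:
    """Overlay vendor symbols onto `namespace`; return the names replaced or added."""
    device = device or infer_device()                    # step 1: detect device
    if device == "cuda":                                 # step 2: CUDA keeps defaults
        return []
    info = VENDOR_REGISTRY.get(device)                   # step 3: registry lookup
    if info is None:
        return []  # assumption: unregistered devices keep the defaults
    module = importlib.import_module(info.module_path)   # step 4: dynamic import
    names = getattr(module, "__all__", None)
    if names is None:
        # Auto-discovery: assume "public" means no leading underscore.
        names = [n for n in vars(module) if not n.startswith("_")]
    for name in names:                                   # step 5: replace/add symbols
        namespace[name] = getattr(module, name)
    return list(names)


# Demo: a fake vendor ops module registered under a hypothetical path.
fake = types.ModuleType("demo_vendor_ops")
fake.LigerGELUMulFunction = "npu-geglu"
fake.__all__ = ["LigerGELUMulFunction"]
sys.modules["demo_vendor_ops"] = fake
VENDOR_REGISTRY["npu"] = VendorInfo("ascend", "npu", "demo_vendor_ops")

ops_namespace = {
    "LigerGELUMulFunction": "default-geglu",
    "LigerRMSNormFunction": "default-rms",
}
replaced = apply_vendor_overrides(ops_namespace)
```

Only the symbols the vendor exports are touched, which is what makes the override incremental: ops without a vendor version keep their default binding.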

Key Design Decisions

1. Vendor-based Directory Naming

Directories are named by vendor (e.g., _ascend, _intel).

2. Self-Registration Pattern

Each vendor registers itself in its __init__.py. This decouples vendor configuration from the central registry, allowing vendors to manage their own metadata independently.

3. Module-level Replacement

Replacement happens at the liger_kernel.ops package level:

  • from liger_kernel.ops import LigerGELUMulFunction — automatically replaced for vendor devices
  • from liger_kernel.ops.geglu import LigerGELUMulFunction — always uses default implementation
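
The difference between the two import paths can be reproduced with stand-in modules. In this sketch demo_ops and demo_ops.geglu are hypothetical substitutes for liger_kernel.ops and liger_kernel.ops.geglu, and strings stand in for the actual classes; it only illustrates that rebinding a name on the package leaves the submodule's binding untouched.

```python
import sys
import types

# Hypothetical stand-in for the submodule holding the default implementation.
geglu = types.ModuleType("demo_ops.geglu")
geglu.LigerGELUMulFunction = "default-geglu"

# Hypothetical stand-in for the package that re-exports the op.
ops = types.ModuleType("demo_ops")
ops.__path__ = []  # mark the stand-in as a package
ops.geglu = geglu
ops.LigerGELUMulFunction = geglu.LigerGELUMulFunction  # package-level re-export

sys.modules["demo_ops"] = ops
sys.modules["demo_ops.geglu"] = geglu

# The replacement step rebinds the name on the package only.
ops.LigerGELUMulFunction = "npu-geglu"

from demo_ops import LigerGELUMulFunction as via_package
from demo_ops.geglu import LigerGELUMulFunction as via_submodule
```

Here via_package resolves to the vendor binding while via_submodule still resolves to the default, which is why the migration moves internal imports to the package level.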

4. Explicit Export Control

Vendors use __all__ to explicitly declare exports. If __all__ is not defined, all public symbols are auto-discovered.

Usage

For Users

No changes required. Existing code works transparently:

from liger_kernel.ops import LigerGELUMulFunction
# Automatically uses vendor implementation on corresponding devices

For Vendor Contributors

  1. Create backends/_<vendor>/ directory structure
  2. Register vendor in _<vendor>/__init__.py
  3. Add import in backends/__init__.py
  4. Implement ops in _<vendor>/ops/
  5. Export in _<vendor>/ops/__init__.py

Migration

Codebase Changes

Imports in transformers/ are updated from submodule level to package level:

# Before
from liger_kernel.ops.geglu import LigerGELUMulFunction

# After
from liger_kernel.ops import LigerGELUMulFunction

Backward Compatibility

  • Direct submodule imports remain functional and always return default implementations
  • No API changes for end users

Example: Ascend NPU GEGLU

# backends/_ascend/ops/geglu.py
import torch

class LigerGELUMulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # NPU-specific implementation with BLOCK_SIZE_SUB
        ...
    
    @staticmethod
    def backward(ctx, dc):
        ...

def geglu_forward(a, b):
    # NPU kernel with UB overflow prevention
    ...

# backends/_ascend/ops/__init__.py
from .geglu import (
    LigerGELUMulFunction,
    geglu_forward,
)

__all__ = ["LigerGELUMulFunction", "geglu_forward"]
