Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -476,4 +476,169 @@ def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> {
}];
}

def WaveMemoryAccessPatternAttr : AttrDef<WaveDialect, "WaveMemoryAccessPattern"> {
let mnemonic = "memory_access_pattern";
let description = [{
This attribute specifies how memory access should be handled during lowering,
particularly for operations that may require LDS (Local Data Store) promotion.

LDS promotion transforms inefficient scalar memory accesses into efficient vectorized
accesses by using Local Data Store (LDS) as an intermediate buffer. The transformation
converts:

**Original Pattern**: Register -> Global Memory (scalar, potentially uncoalesced)

**LDS Promotion Pattern**:
1. Register -> LDS (scalar stores, same as original pattern but to LDS)
2. LDS -> Register (vectorized loads from LDS)
3. Register -> Global Memory (vectorized stores to global memory)

## Index Transformation Logic

### Step 1: Register -> LDS Store
The original global memory indices are transformed to LDS indices by subtracting
the LDS block's base address in global memory coordinates:
```
lds_store_indices = original_global_indices - lds_block_global_base
```

Example:
- Original: `global[WG0 * BLOCK_M + T0 * 4 + offset]`
- LDS base: `WG0 * BLOCK_M`
- LDS store: `lds[(WG0 * BLOCK_M + T0 * 4 + offset) - (WG0 * BLOCK_M)] = lds[T0 * 4 + offset]`

### Step 2: LDS -> Register Load (Vectorized)
Uses `lds_load_mapping` to define vectorized access patterns within the LDS:
```
lds_load_start = start_expr(thread_indices)
lds_load_vector_size = step_expr(thread_indices)
lds_load_stride = stride_expr(thread_indices)
```

### Step 3: Register -> Global Store (Vectorized)
Uses `global_store_mapping` to define vectorized stores back to global memory:
```
global_store_start = start_expr(thread_indices)
global_store_vector_size = step_expr(thread_indices)
global_store_stride = stride_expr(thread_indices)
```

## Complete Example
```mlir
#wave.memory_access_pattern<
use_lds_promotion = true,
group_id = "mfma_result_0",

// LDS block placement: where in global memory this LDS block maps to
lds_block_global_base = #wave.expr_list<
[#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)
>,

// LDS allocation size
lds_block_shape = #wave.expr_list<
[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)
>,

// Vectorized LDS -> Register: each thread loads 64 elements starting at T0*64
lds_load_mapping = #wave.index_mapping<
[#wave.index_symbol<T0>] -> (T0 * 64, 64, 1)
>,

// Vectorized Register -> Global: each thread stores 64 elements
global_store_mapping = #wave.index_mapping<
[#wave.index_symbol<WG0>, #wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">]
-> (WG0 * BLOCK_M + T0 * 64, 64, 1)
>
>
```

## Parameters:
- **use_lds_promotion**: Whether to enable LDS promotion transformation
- **group_id**: String identifier for grouping related operations that share the same LDS allocation
- **lds_block_global_base**: Global memory base address that this LDS block represents (for index transformation)
- **lds_block_shape**: Shape/size of the LDS block allocation in elements
- **lds_load_mapping**: Index mapping for vectorized loads from LDS (start, vector_size, stride)
- **global_store_mapping**: Index mapping for vectorized stores to global memory (start, vector_size, stride)

## Verification Requirements:
- All index mappings must have the same rank as the original global memory tensor
- lds_block_global_base and lds_block_shape must have consistent ranks
- When LDS promotion is enabled, all LDS-related parameters must be specified
}];

let parameters = (ins
"bool":$use_lds_promotion,
StringRefParameter<"group identifier for shared LDS allocation">:$group_id,

// LDS block placement information
OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_global_base,
OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_shape,

// Vectorized access patterns - split into indices and vector sizes for simplicity
OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_indices,
OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_vector_sizes,
OptionalParameter<"::wave::WaveExprListAttr">:$global_store_indices
);

let assemblyFormat = [{
`<` `use_lds_promotion` `=` $use_lds_promotion
`,` `group_id` `=` $group_id
(`,` `lds_block_global_base` `=` $lds_block_global_base^)?
(`,` `lds_block_shape` `=` $lds_block_shape^)?
(`,` `lds_load_indices` `=` $lds_load_indices^)?
(`,` `lds_load_vector_sizes` `=` $lds_load_vector_sizes^)?
(`,` `global_store_indices` `=` $global_store_indices^)? `>`
}];

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Check if LDS promotion is enabled.
bool shouldUseLdsPromotion() const { return getUseLdsPromotion(); }

/// Check if complete LDS promotion information is provided.
bool hasCompleteLdsPromotionInfo() const {
return shouldUseLdsPromotion() &&
getLdsBlockGlobalBase() &&
getLdsBlockShape() &&
getLdsLoadIndices() &&
getLdsLoadVectorSizes() &&
getGlobalStoreIndices();
}

/// Check if any LDS promotion parameters are specified.
bool hasAnyLdsPromotionParams() const {
return getLdsBlockGlobalBase() ||
getLdsBlockShape() ||
getLdsLoadIndices() ||
getLdsLoadVectorSizes() ||
getGlobalStoreIndices();
}

/// Get the rank/dimensionality of the LDS block.
unsigned getLdsBlockRank() const {
if (!getLdsBlockShape()) return 0;
return getLdsBlockShape().getRank();
}

/// Get the rank of the global base address expression.
unsigned getGlobalBaseRank() const {
if (!getLdsBlockGlobalBase()) return 0;
return getLdsBlockGlobalBase().getRank();
}

/// Get the rank of the LDS load indices.
unsigned getLdsLoadIndicesRank() const {
if (!getLdsLoadIndices()) return 0;
return getLdsLoadIndices().getRank();
}

/// Get the rank of the global store indices.
unsigned getGlobalStoreIndicesRank() const {
if (!getGlobalStoreIndices()) return 0;
return getGlobalStoreIndices().getRank();
}
}];
}

#endif // WATER_DIALECT_WAVE_WAVEATTRS
4 changes: 3 additions & 1 deletion water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def WriteOp : WaveOp<"write", [
Arg<OptionalAttr<I64Attr>,
"Number of elements processed by each thread">:$elements_per_thread,
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
"Bound expressions for each symbolic dimension">:$bounds
"Bound expressions for each symbolic dimension">:$bounds,
Arg<OptionalAttr<WaveMemoryAccessPatternAttr>,
"Memory access pattern controlling LDS promotion">:$memory_access_pattern
), commonArguments);

let assemblyFormat =
Expand Down
135 changes: 135 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,141 @@ DeviceConstraintAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// WaveMemoryAccessPatternAttr
//===----------------------------------------------------------------------===//

LogicalResult WaveMemoryAccessPatternAttr::verify(
function_ref<InFlightDiagnostic()> emitError, bool use_lds_promotion,
StringRef group_id, WaveExprListAttr lds_block_global_base,
WaveExprListAttr lds_block_shape, WaveExprListAttr lds_load_indices,
WaveExprListAttr lds_load_vector_sizes,
WaveExprListAttr global_store_indices) {
// Validate group_id is not empty
if (group_id.empty()) {
return emitError() << "group_id cannot be empty";
}

// When LDS promotion is disabled, no LDS-related parameters should be
// specified
if (!use_lds_promotion) {
if (lds_block_global_base || lds_block_shape || lds_load_indices ||
lds_load_vector_sizes || global_store_indices) {
return emitError() << "LDS promotion parameters should not be specified "
"when use_lds_promotion=false";
}
return success();
}

// When LDS promotion is enabled, validate completeness and consistency
bool hasLdsBase = static_cast<bool>(lds_block_global_base);
bool hasLdsShape = static_cast<bool>(lds_block_shape);
bool hasLdsLoadIndices = static_cast<bool>(lds_load_indices);
bool hasLdsLoadVectorSizes = static_cast<bool>(lds_load_vector_sizes);
bool hasGlobalStoreIndices = static_cast<bool>(global_store_indices);

// Check for partial specification - either all or none should be provided
if (hasLdsBase || hasLdsShape || hasLdsLoadIndices || hasLdsLoadVectorSizes ||
hasGlobalStoreIndices) {
if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices ||
!hasLdsLoadVectorSizes || !hasGlobalStoreIndices) {
return emitError() << "when LDS promotion is enabled, all LDS parameters "
"must be specified: "
"lds_block_global_base, lds_block_shape, "
"lds_load_indices, lds_load_vector_sizes, "
"global_store_indices";
}
}

// If all LDS parameters are provided, perform detailed validation
if (hasLdsBase && hasLdsShape && hasLdsLoadIndices && hasLdsLoadVectorSizes &&
hasGlobalStoreIndices) {

// Validate that lds_block_global_base and lds_block_shape have consistent
// ranks
unsigned ldsBaseRank = lds_block_global_base.getRank();
unsigned ldsShapeRank = lds_block_shape.getRank();

if (ldsBaseRank != ldsShapeRank) {
return emitError() << "lds_block_global_base rank (" << ldsBaseRank
<< ") must match lds_block_shape rank ("
<< ldsShapeRank << ")";
}

// Validate that load indices and vector sizes have consistent ranks
unsigned ldsLoadIndicesRank = lds_load_indices.getRank();
unsigned ldsLoadVectorSizesRank = lds_load_vector_sizes.getRank();
unsigned globalStoreIndicesRank = global_store_indices.getRank();

if (ldsLoadIndicesRank != ldsLoadVectorSizesRank) {
return emitError() << "lds_load_indices rank (" << ldsLoadIndicesRank
<< ") must match lds_load_vector_sizes rank ("
<< ldsLoadVectorSizesRank << ")";
}

if (ldsBaseRank != ldsLoadIndicesRank) {
return emitError() << "LDS block rank (" << ldsBaseRank
<< ") must match LDS load indices rank ("
<< ldsLoadIndicesRank << ")";
}

if (ldsBaseRank != globalStoreIndicesRank) {
return emitError() << "LDS block rank (" << ldsBaseRank
<< ") must match global store indices rank ("
<< globalStoreIndicesRank << ")";
}

// Validate that all symbols are WaveSymbolAttr or WaveIndexSymbolAttr
if (!llvm::all_of(lds_block_global_base.getSymbols(),
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
return emitError() << "lds_block_global_base must only contain "
"WaveSymbolAttr or WaveIndexSymbolAttr";
}

if (!llvm::all_of(lds_block_shape.getSymbols(),
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
return emitError() << "lds_block_shape must only contain WaveSymbolAttr "
"or WaveIndexSymbolAttr";
}

if (!llvm::all_of(lds_load_indices.getSymbols(),
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
return emitError() << "lds_load_indices must only contain WaveSymbolAttr "
"or WaveIndexSymbolAttr";
}

if (!llvm::all_of(lds_load_vector_sizes.getSymbols(),
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
return emitError() << "lds_load_vector_sizes must only contain "
"WaveSymbolAttr or WaveIndexSymbolAttr";
}

if (!llvm::all_of(global_store_indices.getSymbols(),
llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
return emitError() << "global_store_indices must only contain "
"WaveSymbolAttr or WaveIndexSymbolAttr";
}

// Validate that mappings have at least one dimension
if (ldsBaseRank == 0) {
return emitError() << "LDS block must have at least one dimension";
}

// Note: We cannot validate that the ranks match the original global memory
// tensor rank here because this attribute verification doesn't have access
// to the WriteOp's memory operand. This validation should be performed in
// the WriteOp's verifier where both the attribute and the memory operand
// type are available.
//
// Additionally, data coverage verification (ensuring that the collective
// workgroup access pattern covers exactly the same elements before and
// after LDS promotion) should be performed in the WriteOp verifier where
// access to the original index mapping is available.
}

return success();
}

void wave::WaveDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
Expand Down
Loading
Loading