Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions water/include/water/c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveSymbolAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveSymbolAttrGet(MlirContext mlirCtx, MlirStringRef symbolName);

/// Creates a new WaveSymbolAttr with the given symbol name, verifying
/// construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveSymbolAttrGetChecked(MlirLocation loc, MlirStringRef symbolName);

/// Returns the typeID of a WaveSymbolAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveSymbolAttrGetTypeID();

Expand Down Expand Up @@ -83,6 +88,12 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGet(
MlirContext mlirCtx, MlirAttribute *symbolNames, MlirAffineMap start,
MlirAffineMap step, MlirAffineMap stride);

/// Creates a new WaveIndexMappingAttr with the given start, step and stride
/// maps, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGetChecked(
MlirLocation loc, MlirAttribute *symbolNames, MlirAffineMap start,
MlirAffineMap step, MlirAffineMap stride);

/// Returns the typeID of a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexMappingAttrGetTypeID();

Expand Down Expand Up @@ -236,6 +247,11 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveExprListAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveExprListAttrGet(MlirAttribute *symbolNames, MlirAffineMap map);

/// Creates a new WaveExprListAttr with the given map, verifying construction
/// invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveExprListAttrGetChecked(
MlirLocation loc, MlirAttribute *symbolNames, MlirAffineMap map);

/// Returns the typeID of a WaveExprListAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveExprListAttrGetTypeID();

Expand Down Expand Up @@ -269,6 +285,12 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirHardwareConstraintAttrGet(
unsigned *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes,
unsigned maxBitsPerLoad);

/// Creates a new HardwareConstraintAttr, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirHardwareConstraintAttrGetChecked(
MlirLocation loc, unsigned threadsPerWave, size_t wavesPerBlockSize,
unsigned *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes,
unsigned maxBitsPerLoad);

/// Returns the typeID of a HardwareConstraintAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWHardwareConstraintAttrGetTypeID();

Expand All @@ -285,6 +307,11 @@ MLIR_CAPI_EXPORTED MlirAttribute
mlirDeviceConstraintAttrGet(MlirContext mlirCtx, MlirAttribute dim,
MlirAttribute tileSize, unsigned deviceDim);

/// Creates a new DeviceConstraintAttr, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute
mlirDeviceConstraintAttrGetChecked(MlirLocation loc, MlirAttribute dim,
MlirAttribute tileSize, unsigned deviceDim);

/// Returns the typeID of a DeviceConstraintAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirDeviceConstraintAttrGetTypeID();

Expand All @@ -301,6 +328,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirWorkgroupConstraintAttrGet(
MlirContext mlirCtx, MlirAttribute dim, MlirAttribute tileSize,
MlirAttribute workgroupDim, bool primary);

/// Creates a new WorkgroupConstraintAttr, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirWorkgroupConstraintAttrGetChecked(
MlirLocation loc, MlirAttribute dim, MlirAttribute tileSize,
MlirAttribute workgroupDim, bool primary);

/// Returns the typeID of a WorkgroupConstraintAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWorkgroupConstraintAttrGetTypeID();

Expand All @@ -315,6 +347,10 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveConstraintAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveConstraintAttrGet(
MlirContext mlirCtx, MlirAttribute dim, MlirAttribute tileSize);

/// Creates a new WaveConstraintAttr, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveConstraintAttrGetChecked(
MlirLocation loc, MlirAttribute dim, MlirAttribute tileSize);

/// Returns the typeID of a WaveConstraintAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveConstraintAttrGetTypeID();

Expand All @@ -330,6 +366,10 @@ mlirAttributeIsATilingConstraintAttr(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute mlirTilingConstraintAttrGet(
MlirContext mlirCtx, MlirAttribute dim, MlirAttribute tileSize);

/// Creates a new TilingConstraintAttr, verifying construction invariants.
MLIR_CAPI_EXPORTED MlirAttribute mlirTilingConstraintAttrGetChecked(
MlirLocation loc, MlirAttribute dim, MlirAttribute tileSize);

/// Returns the typeID of a TilingConstraintAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirTilingConstraintAttrGetTypeID();

Expand Down
179 changes: 179 additions & 0 deletions water/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ MlirAttribute mlirWaveSymbolAttrGet(MlirContext mlirCtx,
return wrap(wave::WaveSymbolAttr::get(ctx, symbolName));
}

MlirAttribute mlirWaveSymbolAttrGetChecked(MlirLocation loc,
MlirStringRef symbolNameStrRef) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();
llvm::StringRef symbolName = unwrap(symbolNameStrRef);
return wrap(wave::WaveSymbolAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, symbolName));
}

MlirTypeID mlirWaveSymbolAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveSymbolAttr>());
}
Expand Down Expand Up @@ -99,6 +108,30 @@ MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx,
unwrap(step), unwrap(stride)));
}

MlirAttribute mlirWaveIndexMappingAttrGetChecked(MlirLocation loc,
MlirAttribute *symbolNames,
MlirAffineMap start,
MlirAffineMap step,
MlirAffineMap stride) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

// Convert C array of MlirAttribute to vector of WaveSymbolAttr.
unsigned numSymbols = mlirAffineMapGetNumSymbols(start);
llvm::SmallVector<mlir::Attribute> symbolAttrs = llvm::map_to_vector(
llvm::make_range(symbolNames, symbolNames + numSymbols),
[](MlirAttribute attr) { return unwrap(attr); });

// Needs explicit conversion to ArrayRef, otherwise we hit the
// Base::getChecked template that can be instantiated for a more specific
// type, but cannot be called from here.
auto attr = wave::WaveIndexMappingAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, llvm::ArrayRef(symbolAttrs),
unwrap(start), unwrap(step), unwrap(stride));

return wrap(attr);
}

MlirTypeID mlirWaveIndexMappingAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveIndexMappingAttr>());
}
Expand Down Expand Up @@ -222,6 +255,25 @@ MlirAttribute mlirWaveExprListAttrGet(MlirAttribute *symbolNames,
return wrap(wave::WaveExprListAttr::get(ctx, symbolAttrs, unwrap(map)));
}

MlirAttribute mlirWaveExprListAttrGetChecked(MlirLocation loc,
MlirAttribute *symbolNames,
MlirAffineMap map) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

unsigned numSymbols = mlirAffineMapGetNumSymbols(map);
llvm::SmallVector<mlir::Attribute> symbolAttrs = llvm::map_to_vector(
llvm::make_range(symbolNames, symbolNames + numSymbols),
[](MlirAttribute attr) { return unwrap(attr); });

// Explicitly convert to ArrayRef, otherwise we hit the Base::getChecked
// template that can be instantiated for a more specific type, but cannot be
// called form here.
return wrap(wave::WaveExprListAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, llvm::ArrayRef(symbolAttrs),
unwrap(map)));
}

MlirTypeID mlirWaveExprListAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveExprListAttr>());
}
Expand Down Expand Up @@ -277,6 +329,32 @@ mlirHardwareConstraintAttrGet(MlirContext mlirCtx, unsigned threadsPerWave,
mmaTypeAttr, vectorShapesAttr, maxBitsPerLoad));
}

MlirAttribute mlirHardwareConstraintAttrGetChecked(
MlirLocation loc, unsigned threadsPerWave, size_t wavesPerBlockSize,
unsigned *wavesPerBlock, MlirAttribute mmaType, MlirAttribute vectorShapes,
unsigned maxBitsPerLoad) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();
auto mmaTypeAttr =
llvm::dyn_cast_or_null<wave::WaveMmaKindAttr>(unwrap(mmaType));
if (!mmaTypeAttr && !mlirAttributeIsNull(mmaType)) {
mlir::emitError(unwrap(loc)) << "expected a WaveMmaKindAttr for mmaType";
return wrap(mlir::Attribute());
}
auto vectorShapesAttr =
llvm::dyn_cast_or_null<mlir::DictionaryAttr>(unwrap(vectorShapes));
if (!vectorShapesAttr && !mlirAttributeIsNull(vectorShapes)) {
mlir::emitError(unwrap(loc))
<< "expected a DictionaryAttr for vectorShapes";
return wrap(mlir::Attribute());
}

return wrap(wave::HardwareConstraintAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, threadsPerWave,
llvm::ArrayRef(wavesPerBlock, wavesPerBlockSize), mmaTypeAttr,
vectorShapesAttr, maxBitsPerLoad));
}

MlirTypeID mlirWHardwareConstraintAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::HardwareConstraintAttr>());
}
Expand All @@ -301,6 +379,30 @@ MlirAttribute mlirDeviceConstraintAttrGet(MlirContext mlirCtx,
wave::DeviceConstraintAttr::get(ctx, dimAttr, tileSizeAttr, deviceDim));
}

MlirAttribute mlirDeviceConstraintAttrGetChecked(MlirLocation loc,
MlirAttribute dim,
MlirAttribute tileSize,
unsigned deviceDim) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

auto dimAttr = llvm::dyn_cast<wave::WaveSymbolAttr>(unwrap(dim));
if (!dimAttr) {
mlir::emitError(cppLoc) << "expected a WaveSymbolAttr for dim";
return wrap(mlir::Attribute());
}

auto tileSizeAttr = llvm::dyn_cast<wave::WaveExprListAttr>(unwrap(tileSize));
if (!tileSizeAttr) {
mlir::emitError(cppLoc) << "expected a WaveExprListAttr for tileSize";
return wrap(mlir::Attribute());
}

return wrap(wave::DeviceConstraintAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, dimAttr, tileSizeAttr,
deviceDim));
}

MlirTypeID mlirDeviceConstraintAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::DeviceConstraintAttr>());
}
Expand Down Expand Up @@ -328,6 +430,39 @@ MlirAttribute mlirWorkgroupConstraintAttrGet(MlirContext mlirCtx,
workgroupDimAttr, primary));
}

MlirAttribute mlirWorkgroupConstraintAttrGetChecked(MlirLocation loc,
MlirAttribute dim,
MlirAttribute tileSize,
MlirAttribute workgroupDim,
bool primary) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

auto dimAttr = llvm::dyn_cast<wave::WaveSymbolAttr>(unwrap(dim));
if (!dimAttr) {
mlir::emitError(cppLoc) << "expected a WaveSymbolAttr for dim";
return wrap(mlir::Attribute());
}

auto tileSizeAttr = llvm::dyn_cast<wave::WaveExprListAttr>(unwrap(tileSize));
if (!tileSizeAttr) {
mlir::emitError(cppLoc) << "expected a WaveExprListAttr for tileSize";
return wrap(mlir::Attribute());
}

auto workgroupDimAttr =
llvm::dyn_cast<wave::WaveWorkgroupDimAttr>(unwrap(workgroupDim));
if (!workgroupDimAttr) {
mlir::emitError(cppLoc)
<< "expected a WaveWorkgroupDimAttr for workgroupDim";
return wrap(mlir::Attribute());
}

return wrap(wave::WorkgroupConstraintAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, dimAttr, tileSizeAttr,
workgroupDimAttr, primary));
}

MlirTypeID mlirWorkgroupConstraintAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WorkgroupConstraintAttr>());
}
Expand All @@ -348,6 +483,28 @@ MlirAttribute mlirWaveConstraintAttrGet(MlirContext mlirCtx, MlirAttribute dim,
return wrap(wave::WaveConstraintAttr::get(ctx, dimAttr, tileSizeAttr));
}

MlirAttribute mlirWaveConstraintAttrGetChecked(MlirLocation loc,
MlirAttribute dim,
MlirAttribute tileSize) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

auto dimAttr = llvm::dyn_cast<wave::WaveSymbolAttr>(unwrap(dim));
if (!dimAttr) {
mlir::emitError(cppLoc) << "expected a WaveSymbolAttr for dim";
return wrap(mlir::Attribute());
}

auto tileSizeAttr = llvm::dyn_cast<wave::WaveExprListAttr>(unwrap(tileSize));
if (!tileSizeAttr) {
mlir::emitError(cppLoc) << "expected a WaveExprListAttr for tileSize";
return wrap(mlir::Attribute());
}

return wrap(wave::WaveConstraintAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, dimAttr, tileSizeAttr));
}

MlirTypeID mlirWaveConstraintAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::WaveConstraintAttr>());
}
Expand All @@ -370,6 +527,28 @@ MlirAttribute mlirTilingConstraintAttrGet(MlirContext mlirCtx,
return wrap(wave::TilingConstraintAttr::get(ctx, dimAttr, tileSizeAttr));
}

MlirAttribute mlirTilingConstraintAttrGetChecked(MlirLocation loc,
MlirAttribute dim,
MlirAttribute tileSize) {
mlir::Location cppLoc = unwrap(loc);
mlir::MLIRContext *ctx = cppLoc.getContext();

auto dimAttr = llvm::dyn_cast<wave::WaveSymbolAttr>(unwrap(dim));
if (!dimAttr) {
mlir::emitError(cppLoc) << "expected a WaveSymbolAttr for dim";
return wrap(mlir::Attribute());
}

auto tileSizeAttr = llvm::dyn_cast<wave::WaveExprListAttr>(unwrap(tileSize));
if (!tileSizeAttr) {
mlir::emitError(cppLoc) << "expected a WaveExprListAttr for tileSize";
return wrap(mlir::Attribute());
}

return wrap(wave::TilingConstraintAttr::getChecked(
[&] { return mlir::emitError(cppLoc); }, ctx, dimAttr, tileSizeAttr));
}

MlirTypeID mlirTilingConstraintAttrGetTypeID() {
return wrap(mlir::TypeID::get<wave::TilingConstraintAttr>());
}
Expand Down
Loading