diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index 081a9b8cad8d6..5e523ec428aef 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -17,13 +17,9 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "llvm/ADT/bit.h" - //===----------------------------------------------------------------------===// // // Type aliases to help code be more self-documenting. Unfortunately @@ -45,40 +41,6 @@ using Level = uint64_t; /// including the value `ShapedType::kDynamic` (for shapes). using Size = int64_t; -/// A simple wrapper to encode a bitset of defined (at most 64) levels. -class LevelSet { - uint64_t bits = 0; - -public: - LevelSet() = default; - explicit LevelSet(uint64_t bits) : bits(bits) {} - operator uint64_t() const { return bits; } - - LevelSet &set(unsigned i) { - assert(i < 64); - bits |= 1 << i; - return *this; - } - - LevelSet &operator|=(LevelSet lhs) { - bits |= static_cast(lhs); - return *this; - } - - LevelSet &lshift(unsigned offset) { - bits = bits << offset; - return *this; - } - - bool operator[](unsigned i) const { - assert(i < 64); - return (bits & (1 << i)) != 0; - } - - unsigned count() const { return llvm::popcount(bits); } - bool empty() const { return bits == 0; } -}; - } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index d5398a98f5b17..4a9b9169ae4b8 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -19,21 +19,6 @@ class SparseTensor_Attr traits = []> : AttrDef; -//===----------------------------------------------------------------------===// -// A simple bitset attribute wrapped over a single int64_t to encode a set of -// sparse tensor levels. -//===----------------------------------------------------------------------===// - -def LevelSetAttr : - TypedAttrBase< - I64, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>, - "LevelSet attribute"> { - let returnType = [{::mlir::sparse_tensor::LevelSet}]; - let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}]; -} - //===----------------------------------------------------------------------===// // These attributes are just like `IndexAttr` except that they clarify whether // the index refers to a dimension (an axis of the semantic tensor) or a level diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index b43d716d5e864..0cfc64f9988a0 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -15,8 +15,6 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/LoopLikeInterface.td" //===----------------------------------------------------------------------===// // Base class. @@ -1279,7 +1277,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator, ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp", - "ForeachOp", "IterateOp"]>]>, + "ForeachOp"]>]>, Arguments<(ins Variadic:$results)> { let summary = "Yield from sparse_tensor set-like operations"; let description = [{ @@ -1432,154 +1430,6 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach", let hasVerifier = 1; } -//===----------------------------------------------------------------------===// -// Sparse Tensor Iteration Operations. -//===----------------------------------------------------------------------===// - -def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space", - [Pure, DeclareOpInterfaceMethods]> { - - let arguments = (ins AnySparseTensor:$tensor, - Optional:$parentIter, - LevelAttr:$loLvl, LevelAttr:$hiLvl); - - let results = (outs AnySparseIterSpace:$resultSpace); - - let summary = "Extract an iteration space from a sparse tensor between certain levels"; - let description = [{ - Extracts a `!sparse_tensor.iter_space` from a sparse tensor between - certian (consecutive) levels. - - `tensor`: the input sparse tensor that defines the iteration space. - `parentIter`: the iterator for the previous level, at which the iteration space - at the current levels will be extracted. - `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that - the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the - iteration space. - - Example: - ```mlir - // Extracts a 1-D iteration space from a COO tensor at level 1. - %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1 - : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> - ``` - }]; - - - let extraClassDeclaration = [{ - std::pair getLvlRange() { - return std::make_pair(getLoLvl(), getHiLvl()); - } - unsigned getSpaceDim() { - return getHiLvl() - getLoLvl(); - } - ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() { - return getResultSpace().getType().getLvlTypes(); - } - }]; - - let hasVerifier = 1; - let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom($loLvl, $hiLvl) " - " attr-dict `:` type($tensor) (`,` type($parentIter)^)?"; -} - -def IterateOp : SparseTensor_Op<"iterate", - [RecursiveMemoryEffects, RecursivelySpeculatable, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> { - - let arguments = (ins AnySparseIterSpace:$iterSpace, - Variadic:$initArgs, - LevelSetAttr:$crdUsedLvls); - let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$region); - - let summary = "Iterate over a sparse iteration space"; - let description = [{ - The `sparse_tensor.iterate` operations represents a loop over the - provided iteration space extracted from a specific sparse tensor. - The operation defines an SSA value for a sparse iterator that points - to the current stored element in the sparse tensor and SSA values - for coordinates of the stored element. The coordinates are always - converted to `index` type despite of the underlying sparse tensor - storage. When coordinates are not used, the SSA values can be skipped - by `_` symbols, which usually leads to simpler generated code after - sparsification. For example: - - ```mlir - // The coordinate for level 0 is not used when iterating over a 2-D - // iteration space. - %sparse_tensor.iterate %iterator in %space at(_, %crd_1) - : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> - ``` - - `sparse_tensor.iterate` can also operate on loop-carried variables - and returns the final values after loop termination. - The initial values of the variables are passed as additional SSA operands - to the iterator SSA value and used coordinate SSA values mentioned - above. The operation region has an argument for the iterator, variadic - arguments for specified (used) coordiates and followed by one argument - for each loop-carried variable, representing the value of the variable - at the current iteration. - The body region must contain exactly one block that terminates with - `sparse_tensor.yield`. - - `sparse_tensor.iterate` results hold the final values after the last - iteration. If the `sparse_tensor.iterate` defines any values, a yield - must be explicitly present. - The number and types of the `sparse_tensor.iterate` results must match - the initial values in the iter_args binding and the yield operands. - - - A nested `sparse_tensor.iterate` example that prints all the coordinates - stored in the sparse input: - - ```mlir - func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) { - // Iterates over the first level of %sp - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0) - : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> { - // Iterates over the second level of %sp - %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 - : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> - %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1) - : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> { - vector.print %crd0 : index - vector.print %crd1 : index - } - } - } - - ``` - }]; - - let extraClassDeclaration = [{ - unsigned getSpaceDim() { - return getIterSpace().getType().getSpaceDim(); - } - BlockArgument getIterator() { - return getRegion().getArguments().front(); - } - Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); - } - unsigned getNumRegionIterArgs() { - return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); - } - }]; - - let hasVerifier = 1; - let hasRegionVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // Sparse Tensor Debugging and Test-Only Operations. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td index 264a0a5b3bee6..185cff46ae25d 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td @@ -72,99 +72,4 @@ def SparseTensorStorageSpecifier : Type($_self)">, "metadata", "::mlir::sparse_tensor::StorageSpecifierType">; -//===----------------------------------------------------------------------===// -// Sparse Tensor Iteration Types. -//===----------------------------------------------------------------------===// - -def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> { - let mnemonic = "iter_space"; - - let description = [{ - A sparse iteration space that represents an abstract N-D (sparse) iteration space - extracted from a sparse tensor. - - Examples: - - ```mlir - // An iteration space extracted from a CSR tensor between levels [0, 2). - !iter_space<#CSR, lvls = 0 to 2> - ``` - }]; - - let parameters = (ins - SparseTensorEncodingAttr : $encoding, - "Level" : $loLvl, - "Level" : $hiLvl - ); - - let extraClassDeclaration = [{ - /// The the dimension of the iteration space. - unsigned getSpaceDim() const { - return getHiLvl() - getLoLvl(); - } - - /// Get the level types for the iteration space. - ArrayRef getLvlTypes() const { - return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim()); - } - - /// Whether the iteration space is unique (i.e., no duplicated coordinate). - bool isUnique() { - return !getLvlTypes().back().isa(); - } - - /// Get the corresponding iterator type. - ::mlir::sparse_tensor::IteratorType getIteratorType() const; - }]; - - let assemblyFormat="`<` $encoding `,` `lvls` `=` custom($loLvl, $hiLvl) `>`"; -} - -def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> { - let mnemonic = "iterator"; - - let description = [{ - An iterator that points to the current element in the corresponding iteration space. - - Examples: - - ```mlir - // An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>` - !iterator<#CSR, lvls = 0 to 2> - ``` - }]; - - let parameters = (ins - SparseTensorEncodingAttr : $encoding, - "Level" : $loLvl, - "Level" : $hiLvl - ); - - let extraClassDeclaration = [{ - /// Get the corresponding iteration space type. - ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const; - - unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); } - ArrayRef getLvlTypes() const { return getIterSpaceType().getLvlTypes(); } - bool isUnique() { return getIterSpaceType().isUnique(); } - }]; - - let assemblyFormat="`<` $encoding `,` `lvls` `=` custom($loLvl, $hiLvl) `>`"; -} - -def IsSparseSparseIterSpaceTypePred - : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">; - -def IsSparseSparseIteratorTypePred - : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">; - -def AnySparseIterSpace - : Type; - -def AnySparseIterator - : Type; - - #endif // SPARSETENSOR_TYPES diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 36908def09f40..e9058394d33da 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -30,14 +30,6 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" -// Forward declarations, following custom print/parsing methods are referenced -// by the generated code for SparseTensorTypes.td. -static mlir::ParseResult parseLevelRange(mlir::AsmParser &, - mlir::sparse_tensor::Level &, - mlir::sparse_tensor::Level &); -static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, - mlir::sparse_tensor::Level); - #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" @@ -1961,363 +1953,6 @@ LogicalResult SortOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// Sparse Tensor Iteration Operations. -//===----------------------------------------------------------------------===// - -IterSpaceType IteratorType::getIterSpaceType() const { - return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), - getHiLvl()); -} - -IteratorType IterSpaceType::getIteratorType() const { - return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); -} - -/// Parses a level range in the form "$lo `to` $hi" -/// or simply "$lo" if $hi - $lo = 1 -static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, - Level &lvlHi) { - if (parser.parseInteger(lvlLo)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("to"))) { - if (parser.parseInteger(lvlHi)) - return failure(); - } else { - lvlHi = lvlLo + 1; - } - - if (lvlHi <= lvlLo) - parser.emitError(parser.getNameLoc(), - "expect larger level upper bound than lower bound"); - - return success(); -} - -/// Parses a level range in the form "$lo `to` $hi" -/// or simply "$lo" if $hi - $lo = 1 -static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, - IntegerAttr &lvlHiAttr) { - Level lvlLo, lvlHi; - if (parseLevelRange(parser, lvlLo, lvlHi)) - return failure(); - - lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); - lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); - return success(); -} - -/// Prints a level range in the form "$lo `to` $hi" -/// or simply "$lo" if $hi - $lo = 1 -static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { - - if (lo + 1 == hi) - p << lo; - else - p << lo << " to " << hi; -} - -/// Prints a level range in the form "$lo `to` $hi" -/// or simply "$lo" if $hi - $lo = 1 -static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, - IntegerAttr lvlHi) { - unsigned lo = lvlLo.getValue().getZExtValue(); - unsigned hi = lvlHi.getValue().getZExtValue(); - printLevelRange(p, lo, hi); -} - -static ParseResult -parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state, - SmallVectorImpl &iterators, - SmallVectorImpl &iterArgs) { - SmallVector spaces; - SmallVector initArgs; - - // Parses "%iters, ... in %spaces, ..." - if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || - parser.parseOperandList(spaces)) - return failure(); - - if (iterators.size() != spaces.size()) - return parser.emitError( - parser.getNameLoc(), - "mismatch in number of sparse iterators and sparse spaces"); - - // Parse "at(%crd0, _, ...)" - LevelSet crdUsedLvlSet; - bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at")); - unsigned lvlCrdCnt = 0; - if (hasUsedCrds) { - ParseResult crdList = parser.parseCommaSeparatedList( - OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { - if (parser.parseOptionalKeyword("_")) { - if (parser.parseArgument(iterArgs.emplace_back())) - return failure(); - // Always use IndexType for the coordinate. - crdUsedLvlSet.set(lvlCrdCnt); - iterArgs.back().type = parser.getBuilder().getIndexType(); - } - lvlCrdCnt += 1; - return success(); - }); - if (failed(crdList)) { - return parser.emitError( - parser.getNameLoc(), - "expecting SSA value or \"_\" for level coordinates"); - } - } - // Set the CrdUsedLvl bitset. - state.addAttribute("crdUsedLvls", - parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); - - // Parse "iter_args(%arg = %init, ...)" - bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); - if (hasIterArgs) - if (parser.parseAssignmentList(iterArgs, initArgs)) - return failure(); - - SmallVector iterSpaceTps; - // parse ": sparse_tensor.iter_space -> ret" - if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) - return failure(); - if (iterSpaceTps.size() != spaces.size()) - return parser.emitError(parser.getNameLoc(), - "mismatch in number of iteration space operands " - "and iteration space types"); - - for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { - IterSpaceType spaceTp = llvm::dyn_cast(tp); - if (!spaceTp) - return parser.emitError(parser.getNameLoc(), - "expected sparse_tensor.iter_space type for " - "iteration space operands"); - if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt) - return parser.emitError(parser.getNameLoc(), - "mismatch in number of iteration space dimension " - "and specified coordinates"); - it.type = spaceTp.getIteratorType(); - } - - if (hasIterArgs) - if (parser.parseArrowTypeList(state.types)) - return failure(); - - // Resolves input operands. - if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), - state.operands)) - return failure(); - - if (hasIterArgs) { - unsigned numCrds = crdUsedLvlSet.count(); - // Strip off leading args that used for coordinates. - MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds); - if (args.size() != initArgs.size() || args.size() != state.types.size()) { - return parser.emitError( - parser.getNameLoc(), - "mismatch in number of iteration arguments and return values"); - } - - for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { - it.type = tp; - if (parser.resolveOperand(init, tp, state.operands)) - return failure(); - } - } - return success(); -} - -LogicalResult ExtractIterSpaceOp::inferReturnTypes( - MLIRContext *ctx, std::optional loc, ValueRange ops, - DictionaryAttr attr, OpaqueProperties prop, RegionRange region, - SmallVectorImpl &ret) { - - ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); - SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); - ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), - adaptor.getHiLvl())); - return success(); -} - -LogicalResult ExtractIterSpaceOp::verify() { - if (getLoLvl() >= getHiLvl()) - return emitOpError("expected smaller level low than level high"); - - TypedValue pIter = getParentIter(); - if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { - return emitOpError( - "parent iterator should be specified iff level lower bound equals 0"); - } - - if (pIter) { - IterSpaceType spaceTp = getResultSpace().getType(); - if (pIter.getType().getEncoding() != spaceTp.getEncoding()) - return emitOpError( - "mismatch in parent iterator encoding and iteration space encoding."); - - if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) - return emitOpError("parent iterator should be used to extract an " - "iteration space from a consecutive level."); - } - - return success(); -} - -ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::Argument iterator; - OpAsmParser::UnresolvedOperand iterSpace; - - SmallVector iters, iterArgs; - if (parseSparseSpaceLoop(parser, result, iters, iterArgs)) - return failure(); - if (iters.size() != 1) - return parser.emitError(parser.getNameLoc(), - "expected only one iterator/iteration space"); - - iters.append(iterArgs); - Region *body = result.addRegion(); - if (parser.parseRegion(*body, iters)) - return failure(); - - IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); - - // Parse the optional attribute list. - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - return success(); -} - -/// Prints the initialization list in the form of -/// (%inner = %outer, %inner2 = %outer2, <...>) -/// where 'inner' values are assumed to be region arguments and 'outer' values -/// are regular SSA values. -static void printInitializationList(OpAsmPrinter &p, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - p << prefix << '('; - llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { - p << std::get<0>(it) << " = " << std::get<1>(it); - }); - p << ")"; -} - -static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim, - Block::BlockArgListType blocksArgs, - LevelSet crdUsedLvls) { - if (crdUsedLvls.empty()) - return; - - p << " at("; - for (unsigned i = 0; i < spaceDim; i++) { - if (crdUsedLvls[i]) { - p << blocksArgs.front(); - blocksArgs = blocksArgs.drop_front(); - } else { - p << "_"; - } - if (i != spaceDim - 1) - p << ", "; - } - assert(blocksArgs.empty()); - p << ")"; -} - -void IterateOp::print(OpAsmPrinter &p) { - p << " " << getIterator() << " in " << getIterSpace(); - printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); - printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); - - p << " : " << getIterSpace().getType() << " "; - if (!getInitArgs().empty()) - p << "-> (" << getInitArgs().getTypes() << ") "; - - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/!getInitArgs().empty()); -} - -LogicalResult IterateOp::verify() { - if (getInitArgs().size() != getNumResults()) { - return emitOpError( - "mismatch in number of loop-carried values and defined values"); - } - return success(); -} - -LogicalResult IterateOp::verifyRegions() { - if (getIterator().getType() != getIterSpace().getType().getIteratorType()) - return emitOpError("mismatch in iterator and iteration space type"); - if (getNumRegionIterArgs() != getNumResults()) - return emitOpError( - "mismatch in number of basic block args and defined values"); - - auto initArgs = getInitArgs(); - auto iterArgs = getRegionIterArgs(); - auto yieldVals = getYieldedValues(); - auto opResults = getResults(); - if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), - opResults.size()})) { - return emitOpError() << "number mismatch between iter args and results."; - } - - for (auto [i, init, iter, yield, ret] : - llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { - if (init.getType() != ret.getType()) - return emitOpError() << "types mismatch between " << i - << "th iter operand and defined value"; - if (iter.getType() != ret.getType()) - return emitOpError() << "types mismatch between " << i - << "th iter region arg and defined value"; - if (yield.getType() != ret.getType()) - return emitOpError() << "types mismatch between " << i - << "th yield value and defined value"; - } - - return success(); -} - -/// IterateOp implemented OpInterfaces' methods. -SmallVector IterateOp::getLoopRegions() { return {&getRegion()}; } - -MutableArrayRef IterateOp::getInitsMutable() { - return getInitArgsMutable(); -} - -Block::BlockArgListType IterateOp::getRegionIterArgs() { - return getRegion().getArguments().take_back(getNumRegionIterArgs()); -} - -std::optional> IterateOp::getYieldedValuesMutable() { - return cast( - getRegion().getBlocks().front().getTerminator()) - .getResultsMutable(); -} - -std::optional IterateOp::getLoopResults() { return getResults(); } - -OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { - return getInitArgs(); -} - -void IterateOp::getSuccessorRegions(RegionBranchPoint point, - SmallVectorImpl ®ions) { - // Both the operation itself and the region may be branching into the body or - // back into the operation itself. - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - // It is possible for loop not to enter the body. - regions.push_back(RegionSuccessor(getResults())); -} - -//===----------------------------------------------------------------------===// -// Sparse Tensor Dialect Setups. -//===----------------------------------------------------------------------===// - /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index b13024cd4ed99..7f5c05190fc9a 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -1012,142 +1012,3 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) { sparse_tensor.print %arg0 : tensor<10x10xf64> return } - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) { - // expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}} - %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2> - return -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) { - // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}} - %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> - return -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) { - // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}} - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> - return -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -#CSR = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : dense, - j : compressed - ) -}> - -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) { - // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}} - %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0> - return -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) { - // expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}} - %l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> - return -} - - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - // expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}} - %r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) { - sparse_tensor.yield %si : index - } - return %r1 : index -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -// expected-note@+1 {{prior use here}} -func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - // expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}} - %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 { - sparse_tensor.yield %outer : f32 - } - return %r1 : f32 -} - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}} - %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index { - %y = arith.constant 1.0 : f32 - sparse_tensor.yield %y : f32 - } - return %r1 : index -} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index e9a898f16b41d..12f69c1d37b9c 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -738,56 +738,3 @@ func.func @sparse_has_runtime() -> i1 { %has_runtime = sparse_tensor.has_runtime_library return %has_runtime : i1 } - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -// CHECK-LABEL: func.func @sparse_extract_iter_space( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>, -// CHECK-SAME: %[[VAL_1:.*]]: !sparse_tensor.iterator<#sparse{{[0-9]*}}, lvls = 0>) -// CHECK: %[[VAL_2:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 -// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] at %[[VAL_1]] lvls = 1 -// CHECK: return %[[VAL_2]], %[[VAL_3]] : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0>, !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 1> -// CHECK: } -func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) - -> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) { - // Extracting the iteration space for the first level needs no parent iterator. - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - // Extracting the iteration space for the second level needs a parent iterator. - %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0> - return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1> -} - - -// ----- - -#COO = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : compressed(nonunique), - j : singleton(soa) - ) -}> - -// CHECK-LABEL: func.func @sparse_iterate( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>, -// CHECK-SAME: %[[VAL_1:.*]]: index, -// CHECK-SAME: %[[VAL_2:.*]]: index) -> index { -// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}> -// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) { -// CHECK: sparse_tensor.yield %[[VAL_7]] : index -// CHECK: } -// CHECK: return %[[VAL_4]] : index -// CHECK: } -func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> - %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index { - sparse_tensor.yield %outer : index - } - return %r1 : index -} diff --git a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir b/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir deleted file mode 100644 index e7158d04b37fe..0000000000000 --- a/mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s - -#CSR = #sparse_tensor.encoding<{ - map = (i, j) -> ( - i : dense, - j : compressed - ) -}> - -// Make sure that pure instructions are hoisted outside the loop. -// -// CHECK: sparse_tensor.values -// CHECK: sparse_tensor.positions -// CHECK: sparse_tensor.coordinate -// CHECK: sparse_tensor.iterate -func.func @sparse_iterate(%sp : tensor) { - %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor - sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> { - %0 = sparse_tensor.values %sp : tensor to memref - %1 = sparse_tensor.positions %sp { level = 1 : index } : tensor to memref - %2 = sparse_tensor.coordinates %sp { level = 1 : index } : tensor to memref - "test.op"(%0, %1, %2) : (memref, memref, memref) -> () - } - - return -}