diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 92cb7075005a3..05f2f1e7057cc 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -85,7 +85,34 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { SmallVector newShape(1 + oldMemRefType.getRank()); newShape[0] = 2; std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1); - return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); + + bool isDynamic = false; + for (int64_t dim : oldShape) { + if (dim == ShapedType::kDynamic) { + isDynamic = true; + break; + } + } + MemRefLayoutAttrInterface newLayout = {}; + if (auto oldLayout = oldMemRefType.getLayout()) { + if (auto stridedLayout = dyn_cast(oldLayout)) { + // Calculate leading stride + ArrayRef oldStrides = stridedLayout.getStrides(); + int64_t bufferStride = + isDynamic ? ShapedType::kDynamic : oldShape[0] * oldStrides[0]; + SmallVector newStrides; + newStrides.push_back(bufferStride); + newStrides.append(oldStrides.begin(), oldStrides.end()); + + MLIRContext *context = oldMemRefType.getContext(); + newLayout = StridedLayoutAttr::get(context, stridedLayout.getOffset(), + newStrides); + } + } + + return MemRefType::Builder(oldMemRefType) + .setShape(newShape) + .setLayout(newLayout); }; auto oldMemRefType = cast(oldMemRef.getType()); @@ -102,8 +129,14 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { } // Create and place the alloc right before the 'affine.for' operation. - Value newMemRef = bOuter.create( - forOp.getLoc(), newMemRefType, allocOperands); + Value newMemRef; + if (auto *definingOp = oldMemRef.getDefiningOp()) { + newMemRef = bOuter.create( + forOp.getLoc(), newMemRefType, allocOperands, definingOp->getAttrs()); + } else { + newMemRef = bOuter.create(forOp.getLoc(), newMemRefType, + allocOperands); + } // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0);