diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 46e5780e151f..4a2999aeaa37 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2427,7 +2427,19 @@ struct FoldTensorCastOp : public RewritePattern { // Clone op. Operation *newOp = linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands); - rewriter.replaceOp(op, newOp->getResults()); + SmallVector replacements; + replacements.reserve(newOp->getNumResults()); + for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { + Value oldResult = std::get<0>(result); + Value newResult = std::get<1>(result); + if (newResult.getType() != oldResult.getType()) { + replacements.push_back(rewriter.create( + op->getLoc(), oldResult.getType(), newResult)); + } else { + replacements.push_back(newResult); + } + } + rewriter.replaceOp(op, replacements); return success(); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 536d71d89d4f..7a71f09adf63 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -3305,7 +3305,11 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op, static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op, SubTensorOp newOp) { - rewriter.replaceOpWithNewOp(op, op.getType(), newOp); + Value replacement = newOp.getResult(); + if (replacement.getType() != op.getType()) + replacement = + rewriter.create(op.getLoc(), op.getType(), replacement); + rewriter.replaceOp(op, replacement); } /// Pattern to rewrite a subview op with constant arguments. @@ -3789,11 +3793,10 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result, } OpFoldResult SubTensorInsertOp::fold(ArrayRef) { - if (getSourceType() == getType() && + if (getSourceType().hasStaticShape() && getType().hasStaticShape() && + getSourceType() == getType() && succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) return this->source(); - if (succeeded(tensor::foldTensorCast(*this))) - return this->source(); return OpFoldResult(); } @@ -3847,9 +3850,9 @@ struct SubTensorInsertOpCastFolder final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp, + LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp, PatternRewriter &rewriter) const override { - if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) { + if (llvm::any_of(subTensorInsertOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return failure(); @@ -3860,21 +3863,25 @@ struct SubTensorInsertOpCastFolder final return llvm::None; return castOp.source(); }; - Optional sourceCastSource = getSourceOfCastOp(subTensorOp.source()); - Optional destCastSource = getSourceOfCastOp(subTensorOp.dest()); - if (!sourceCastSource && !destCastSource && - subTensorOp.dest().getType() == subTensorOp.getResult().getType()) + Optional sourceCastSource = + getSourceOfCastOp(subTensorInsertOp.source()); + Optional destCastSource = + getSourceOfCastOp(subTensorInsertOp.dest()); + if (!sourceCastSource && !destCastSource) return failure(); - auto newOp = rewriter.create( - subTensorOp.getLoc(), - (sourceCastSource ? *sourceCastSource : subTensorOp.source()), - (destCastSource ? *destCastSource : subTensorOp.dest()), - subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(), - subTensorOp.getMixedStrides()); + Value replacement = rewriter.create( + subTensorInsertOp.getLoc(), + (sourceCastSource ? *sourceCastSource : subTensorInsertOp.source()), + (destCastSource ? *destCastSource : subTensorInsertOp.dest()), + subTensorInsertOp.getMixedOffsets(), subTensorInsertOp.getMixedSizes(), + subTensorInsertOp.getMixedStrides()); - rewriter.replaceOpWithNewOp(subTensorOp, - subTensorOp.getType(), newOp); + if (replacement.getType() != subTensorInsertOp.getType()) { + replacement = rewriter.create( + subTensorInsertOp.getLoc(), subTensorInsertOp.getType(), replacement); + } + rewriter.replaceOp(subTensorInsertOp, replacement); return success(); } }; diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 2fb5eb3086e6..f2f3a44169e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -767,3 +767,25 @@ func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index) // CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]] // CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]] // CHECK: return %[[C5]], %[[D1]] + +// ----- + +func @propogate_casts(%arg0 : tensor, %arg1 : f32, %arg2 : index, + %arg3 : index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c21 = constant 21 : index + %c42 = constant 42 : index + %0 = linalg.init_tensor [%c21, %c42] : tensor + %1 = linalg.fill(%0, %arg1) : tensor, f32 -> tensor + %2 = dim %arg0, %c0 : tensor + %3 = dim %arg0, %c1 : tensor + %4 = subtensor_insert %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor into tensor + return %4 : tensor +} +// CHECK-LABEL: func @propogate_casts +// CHECK: %[[INIT:.+]] = linalg.init_tensor [21, 42] +// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}}) +// CHECK: %[[INSERTED:.+]] = subtensor_insert %{{.+}} into %[[FILL]] +// CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]] +// CHECK: return %[[RESULT]]