forked from OSchip/llvm-project
[mlir] Remove incorrect folding for SubTensorInsertOp
The SubTensorInsertOp has a requirement that dest type and result type match. Just folding the tensor.cast operation violates this and creates verification errors during canonicalization. Also fix other canonicalization methods that werent inserting casts properly. Differential Revision: https://reviews.llvm.org/D97800
This commit is contained in:
parent
75805dce5f
commit
c118fdcd59
|
@ -2427,7 +2427,19 @@ struct FoldTensorCastOp : public RewritePattern {
|
|||
// Clone op.
|
||||
Operation *newOp =
|
||||
linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
SmallVector<Value, 4> replacements;
|
||||
replacements.reserve(newOp->getNumResults());
|
||||
for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
|
||||
Value oldResult = std::get<0>(result);
|
||||
Value newResult = std::get<1>(result);
|
||||
if (newResult.getType() != oldResult.getType()) {
|
||||
replacements.push_back(rewriter.create<tensor::CastOp>(
|
||||
op->getLoc(), oldResult.getType(), newResult));
|
||||
} else {
|
||||
replacements.push_back(newResult);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, replacements);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -3305,7 +3305,11 @@ static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op,
|
|||
|
||||
static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
|
||||
SubTensorOp newOp) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
|
||||
Value replacement = newOp.getResult();
|
||||
if (replacement.getType() != op.getType())
|
||||
replacement =
|
||||
rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), replacement);
|
||||
rewriter.replaceOp(op, replacement);
|
||||
}
|
||||
|
||||
/// Pattern to rewrite a subview op with constant arguments.
|
||||
|
@ -3789,11 +3793,10 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
|
|||
}
|
||||
|
||||
OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
|
||||
if (getSourceType() == getType() &&
|
||||
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
|
||||
getSourceType() == getType() &&
|
||||
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
|
||||
return this->source();
|
||||
if (succeeded(tensor::foldTensorCast(*this)))
|
||||
return this->source();
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
||||
|
@ -3847,9 +3850,9 @@ struct SubTensorInsertOpCastFolder final
|
|||
: public OpRewritePattern<SubTensorInsertOp> {
|
||||
using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
|
||||
LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
|
||||
if (llvm::any_of(subTensorInsertOp.getOperands(), [](Value operand) {
|
||||
return matchPattern(operand, m_ConstantIndex());
|
||||
}))
|
||||
return failure();
|
||||
|
@ -3860,21 +3863,25 @@ struct SubTensorInsertOpCastFolder final
|
|||
return llvm::None;
|
||||
return castOp.source();
|
||||
};
|
||||
Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
|
||||
Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
|
||||
if (!sourceCastSource && !destCastSource &&
|
||||
subTensorOp.dest().getType() == subTensorOp.getResult().getType())
|
||||
Optional<Value> sourceCastSource =
|
||||
getSourceOfCastOp(subTensorInsertOp.source());
|
||||
Optional<Value> destCastSource =
|
||||
getSourceOfCastOp(subTensorInsertOp.dest());
|
||||
if (!sourceCastSource && !destCastSource)
|
||||
return failure();
|
||||
|
||||
auto newOp = rewriter.create<SubTensorInsertOp>(
|
||||
subTensorOp.getLoc(),
|
||||
(sourceCastSource ? *sourceCastSource : subTensorOp.source()),
|
||||
(destCastSource ? *destCastSource : subTensorOp.dest()),
|
||||
subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
|
||||
subTensorOp.getMixedStrides());
|
||||
Value replacement = rewriter.create<SubTensorInsertOp>(
|
||||
subTensorInsertOp.getLoc(),
|
||||
(sourceCastSource ? *sourceCastSource : subTensorInsertOp.source()),
|
||||
(destCastSource ? *destCastSource : subTensorInsertOp.dest()),
|
||||
subTensorInsertOp.getMixedOffsets(), subTensorInsertOp.getMixedSizes(),
|
||||
subTensorInsertOp.getMixedStrides());
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
|
||||
subTensorOp.getType(), newOp);
|
||||
if (replacement.getType() != subTensorInsertOp.getType()) {
|
||||
replacement = rewriter.create<tensor::CastOp>(
|
||||
subTensorInsertOp.getLoc(), subTensorInsertOp.getType(), replacement);
|
||||
}
|
||||
rewriter.replaceOp(subTensorInsertOp, replacement);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -767,3 +767,25 @@ func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
|
|||
// CHECK: %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
|
||||
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
|
||||
// CHECK: return %[[C5]], %[[D1]]
|
||||
|
||||
// -----
|
||||
|
||||
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
|
||||
%arg3 : index) -> tensor<?x?xf32> {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c21 = constant 21 : index
|
||||
%c42 = constant 42 : index
|
||||
%0 = linalg.init_tensor [%c21, %c42] : tensor<?x?xf32>
|
||||
%1 = linalg.fill(%0, %arg1) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
|
||||
%2 = dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%3 = dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%4 = subtensor_insert %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
|
||||
return %4 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @propogate_casts
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [21, 42]
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %{{.+}})
|
||||
// CHECK: %[[INSERTED:.+]] = subtensor_insert %{{.+}} into %[[FILL]]
|
||||
// CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
|
Loading…
Reference in New Issue