forked from OSchip/llvm-project
[mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult
This also replaces usages of matchSuccess/matchFailure with success/failure respectively. Differential Revision: https://reviews.llvm.org/D76313
This commit is contained in:
parent
2fae7878d5
commit
3145427dd7
|
@ -247,7 +247,7 @@ struct MyConversionPattern : public ConversionPattern {
|
|||
/// The `matchAndRewrite` hooks on ConversionPatterns take an additional
|
||||
/// `operands` parameter, containing the remapped operands of the original
|
||||
/// operation.
|
||||
virtual PatternMatchResult
|
||||
virtual LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const;
|
||||
};
|
||||
|
|
|
@ -171,8 +171,8 @@ struct ConvertTFLeakyRelu : public RewritePattern {
|
|||
ConvertTFLeakyRelu(MLIRContext *context)
|
||||
: RewritePattern("tf.LeakyRelu", 1, context) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
return matchSuccess();
|
||||
LogicalResult match(Operation *op) const override {
|
||||
return success();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
|
@ -188,12 +188,12 @@ struct ConvertTFLeakyRelu : public RewritePattern {
|
|||
ConvertTFLeakyRelu(MLIRContext *context)
|
||||
: RewritePattern("tf.LeakyRelu", 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
|
||||
op, op->getResult(0).getType(), op->getOperand(0),
|
||||
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
```
|
||||
|
|
|
@ -86,7 +86,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method is attempting to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -96,11 +96,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
```
|
||||
|
|
|
@ -106,7 +106,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
|
|||
|
||||
/// Match and rewrite the given `toy.transpose` operation, with the given
|
||||
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
|
||||
mlir::ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -132,7 +132,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
|
|||
SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
```
|
||||
|
|
|
@ -35,7 +35,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -45,11 +45,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
BinaryOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
|
||||
|
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
|
|||
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
DenseElementsAttr constantValue = op.value();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
|
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
|
||||
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
// During this lowering, we expect that all function calls have been
|
||||
// inlined.
|
||||
if (op.hasOperand())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// We lower "toy.return" directly to "std.return".
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
TransposeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
BinaryOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
|
||||
|
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
|
|||
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
DenseElementsAttr constantValue = op.value();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
|
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
|
||||
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
// During this lowering, we expect that all function calls have been
|
||||
// inlined.
|
||||
if (op.hasOperand())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// We lower "toy.return" directly to "std.return".
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
TransposeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ public:
|
|||
explicit PrintOpLowering(MLIRContext *context)
|
||||
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
||||
|
@ -91,7 +91,7 @@ public:
|
|||
|
||||
// Notify the rewriter that this operation has been removed.
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
BinaryOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
|
|||
// Create the binary operation performed on the loaded values.
|
||||
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
|
||||
|
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
|
|||
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ConstantOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
DenseElementsAttr constantValue = op.value();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
|
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
|
||||
// Replace this operation with the generated alloc.
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
|||
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
|
||||
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(toy::ReturnOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
// During this lowering, we expect that all function calls have been
|
||||
// inlined.
|
||||
if (op.hasOperand())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// We lower "toy.return" directly to "std.return".
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
TransposeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
|
||||
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ public:
|
|||
explicit PrintOpLowering(MLIRContext *context)
|
||||
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
||||
|
@ -91,7 +91,7 @@ public:
|
|||
|
||||
// Notify the rewriter that this operation has been removed.
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -58,7 +58,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. The pattern is
|
||||
/// expected to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
|
@ -68,11 +68,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
|
||||
// Input defined by another transpose? If not, no match.
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Otherwise, we have a redundant transpose. Use the rewriter.
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class TileAndFuseLinalgOp<
|
|||
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
|
||||
" \"" # value # "\")))" #
|
||||
" return matchFailure();">;
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg tiling patterns.
|
||||
|
@ -70,22 +70,22 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
|
|||
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
|
||||
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
|
||||
StrJoinInt<permutation>.result # "})))" #
|
||||
" return matchFailure();">;
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to loop patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class LinalgOpToLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
" return failure();">;
|
||||
|
||||
class LinalgOpToParallelLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
" return failure();">;
|
||||
|
||||
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
|
||||
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
|
||||
" return matchFailure();">;
|
||||
" return failure();">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Linalg to vector patterns precondition and DRR.
|
||||
|
|
|
@ -54,10 +54,6 @@ private:
|
|||
unsigned short representation;
|
||||
};
|
||||
|
||||
/// This is the type returned by a pattern match.
|
||||
/// TODO: Replace usages with LogicalResult directly.
|
||||
using PatternMatchResult = LogicalResult;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern class
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -85,20 +81,10 @@ public:
|
|||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind().
|
||||
virtual PatternMatchResult match(Operation *op) const = 0;
|
||||
virtual LogicalResult match(Operation *op) const = 0;
|
||||
|
||||
virtual ~Pattern() {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Helper methods to simplify pattern implementations
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Return a result, indicating that no match was found.
|
||||
PatternMatchResult matchFailure() const { return failure(); }
|
||||
|
||||
/// This method indicates that a match was found.
|
||||
PatternMatchResult matchSuccess() const { return success(); }
|
||||
|
||||
protected:
|
||||
/// Patterns must specify the root operation name they match against, and can
|
||||
/// also specify the benefit of the pattern matching.
|
||||
|
@ -130,22 +116,19 @@ public:
|
|||
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
|
||||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind(). On failure, this
|
||||
/// returns a None value. On success, it returns a (possibly null)
|
||||
/// pattern-specific state wrapped in an Optional. This state is passed back
|
||||
/// into the rewrite function if this match is selected.
|
||||
PatternMatchResult match(Operation *op) const override;
|
||||
/// which is the same operation code as getRootKind().
|
||||
LogicalResult match(Operation *op) const override;
|
||||
|
||||
/// Attempt to match against code rooted at the specified operation,
|
||||
/// which is the same operation code as getRootKind(). If successful, this
|
||||
/// function will automatically perform the rewrite.
|
||||
virtual PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
virtual LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (succeeded(match(op))) {
|
||||
rewrite(op, rewriter);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Return a list of operations that may be generated when rewriting an
|
||||
|
@ -182,11 +165,11 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
|
|||
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
|
||||
rewrite(cast<SourceOp>(op), rewriter);
|
||||
}
|
||||
PatternMatchResult match(Operation *op) const final {
|
||||
LogicalResult match(Operation *op) const final {
|
||||
return match(cast<SourceOp>(op));
|
||||
}
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
return matchAndRewrite(cast<SourceOp>(op), rewriter);
|
||||
}
|
||||
|
||||
|
@ -195,16 +178,16 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
|
|||
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
|
||||
llvm_unreachable("must override rewrite or matchAndRewrite");
|
||||
}
|
||||
virtual PatternMatchResult match(SourceOp op) const {
|
||||
virtual LogicalResult match(SourceOp op) const {
|
||||
llvm_unreachable("must override match or matchAndRewrite");
|
||||
}
|
||||
virtual PatternMatchResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
virtual LogicalResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (succeeded(match(op))) {
|
||||
rewrite(op, rewriter);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -235,18 +235,18 @@ public:
|
|||
}
|
||||
|
||||
/// Hook for derived classes to implement combined matching and rewriting.
|
||||
virtual PatternMatchResult
|
||||
virtual LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (failed(match(op)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewrite(op, operands, rewriter);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Attempt to match and rewrite the IR root at the specified operation.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
|
||||
private:
|
||||
using RewritePattern::rewrite;
|
||||
|
@ -266,7 +266,7 @@ struct OpConversionPattern : public ConversionPattern {
|
|||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
|
||||
|
@ -282,13 +282,13 @@ struct OpConversionPattern : public ConversionPattern {
|
|||
llvm_unreachable("must override matchAndRewrite or a rewrite method");
|
||||
}
|
||||
|
||||
virtual PatternMatchResult
|
||||
virtual LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (failed(match(op)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewrite(op, operands, rewriter);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -297,15 +297,15 @@ class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineMinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineMinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value reduced =
|
||||
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
|
||||
if (!reduced)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, reduced);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -313,15 +313,15 @@ class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineMaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineMaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value reduced =
|
||||
lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
|
||||
if (!reduced)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, reduced);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -330,10 +330,10 @@ class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineTerminatorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<loop::YieldOp>(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -341,8 +341,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineForOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineForOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineForOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value lowerBound = lowerAffineLowerBound(op, rewriter);
|
||||
Value upperBound = lowerAffineUpperBound(op, rewriter);
|
||||
|
@ -351,7 +351,7 @@ public:
|
|||
f.region().getBlocks().clear();
|
||||
rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -359,8 +359,8 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineIfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineIfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
// Now we just have to handle the condition logic.
|
||||
|
@ -381,7 +381,7 @@ public:
|
|||
operandsRef.take_front(numDims),
|
||||
operandsRef.drop_front(numDims));
|
||||
if (!affResult)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
|
||||
Value cmpVal =
|
||||
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
|
||||
|
@ -402,7 +402,7 @@ public:
|
|||
|
||||
// Ok, we're done!
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -412,15 +412,15 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineApplyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineApplyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
|
||||
llvm::to_vector<8>(op.getOperands()));
|
||||
if (!maybeExpandedMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOp(op, *maybeExpandedMap);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -431,18 +431,18 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineLoadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineLoadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Expand affine map from 'affineLoadOp'.
|
||||
SmallVector<Value, 8> indices(op.getMapOperands());
|
||||
auto resultOperands =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!resultOperands)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Build std.load memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -453,20 +453,20 @@ class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffinePrefetchOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Expand affine map from 'affinePrefetchOp'.
|
||||
SmallVector<Value, 8> indices(op.getMapOperands());
|
||||
auto resultOperands =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!resultOperands)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Build std.prefetch memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<PrefetchOp>(
|
||||
op, op.memref(), *resultOperands, op.isWrite(),
|
||||
op.localityHint().getZExtValue(), op.isDataCache());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -477,19 +477,19 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineStoreOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineStoreOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Expand affine map from 'affineStoreOp'.
|
||||
SmallVector<Value, 8> indices(op.getMapOperands());
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!maybeExpandedMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Build std.store valueToStore, memref[expandedMap.results].
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
|
||||
op.getMemRef(), *maybeExpandedMap);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -500,8 +500,8 @@ class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineDmaStartOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineDmaStartOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value, 8> operands(op.getOperands());
|
||||
auto operandsRef = llvm::makeArrayRef(operands);
|
||||
|
||||
|
@ -510,26 +510,26 @@ public:
|
|||
rewriter, op.getLoc(), op.getSrcMap(),
|
||||
operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedSrcMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Expand affine map for DMA destination memref.
|
||||
auto maybeExpandedDstMap = expandAffineMap(
|
||||
rewriter, op.getLoc(), op.getDstMap(),
|
||||
operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedDstMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Expand affine map for DMA tag memref.
|
||||
auto maybeExpandedTagMap = expandAffineMap(
|
||||
rewriter, op.getLoc(), op.getTagMap(),
|
||||
operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
|
||||
if (!maybeExpandedTagMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Build std.dma_start operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<DmaStartOp>(
|
||||
op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
|
||||
*maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
|
||||
*maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -540,19 +540,19 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
|
|||
public:
|
||||
using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineDmaWaitOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineDmaWaitOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Expand affine map for DMA tag memref.
|
||||
SmallVector<Value, 8> indices(op.getTagIndices());
|
||||
auto maybeExpandedTagMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
|
||||
if (!maybeExpandedTagMap)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Build std.dma_wait operation with affine map results.
|
||||
rewriter.replaceOpWithNewOp<DmaWaitOp>(
|
||||
op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ public:
|
|||
indexBitwidth(getIndexBitWidth(lowering_)) {}
|
||||
|
||||
// Convert the kernel arguments to an LLVM type, preserve the rest.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -63,7 +63,7 @@ public:
|
|||
newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
|
||||
break;
|
||||
default:
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (indexBitwidth > 32) {
|
||||
|
@ -75,7 +75,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {newOp});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ public:
|
|||
lowering_.getDialect()->getContext(), lowering_),
|
||||
f32Func(f32Func), f64Func(f64Func) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
using LLVM::LLVMFuncOp;
|
||||
|
@ -49,13 +49,13 @@ public:
|
|||
LLVMType funcType = getFunctionType(resultType, operands);
|
||||
StringRef funcName = getFunctionName(resultType);
|
||||
if (funcName.empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(
|
||||
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
|
||||
rewriter.replaceOp(op, {callOp.getResult(0)});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -51,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
|
|||
/// !llvm<"{ float, i1 }">
|
||||
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
|
||||
/// !llvm<"{ float, i1 }">
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
|
@ -84,7 +84,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
|
|||
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
|
||||
|
||||
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -94,7 +94,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
|
|||
typeConverter.getDialect()->getContext(),
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
assert(operands.empty() && "func op is not expected to have operands");
|
||||
|
@ -219,7 +219,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
|
|||
signatureConversion);
|
||||
|
||||
rewriter.eraseOp(gpuFuncOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -229,11 +229,11 @@ struct GPUReturnOpLowering : public ConvertToLLVMPattern {
|
|||
typeConverter.getDialect()->getContext(),
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -37,7 +37,7 @@ class IfOpConversion final : public SPIRVOpLowering<loop::IfOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<loop::IfOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(loop::IfOp IfOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -47,11 +47,11 @@ class TerminatorOpConversion final : public SPIRVOpLowering<loop::YieldOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<loop::YieldOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(terminatorOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -62,7 +62,7 @@ class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -75,7 +75,7 @@ class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<gpu::BlockDimOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -85,7 +85,7 @@ class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<gpu::GPUFuncOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
|
||||
|
@ -98,7 +98,7 @@ class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -109,7 +109,7 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -120,7 +120,7 @@ public:
|
|||
// loop::ForOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// loop::ForOp can be lowered to the structured control flow represented by
|
||||
|
@ -186,14 +186,14 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
|
|||
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
|
||||
|
||||
rewriter.eraseOp(forOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// loop::IfOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// When lowering `loop::IfOp` we explicitly create a selection header block
|
||||
|
@ -238,7 +238,7 @@ IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
|
|||
elseBlock, ArrayRef<Value>());
|
||||
|
||||
rewriter.eraseOp(ifOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -261,36 +261,36 @@ static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
|
|||
}
|
||||
|
||||
template <typename SourceOp, spirv::BuiltIn builtin>
|
||||
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||
SourceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto index = getLaunchConfigIndex(op);
|
||||
if (!index)
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
|
||||
// SPIR-V invocation builtin variables are a vector of type <3xi32>
|
||||
auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
|
||||
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
|
||||
op, rewriter.getIntegerType(32), spirvBuiltin,
|
||||
rewriter.getI32ArrayAttr({index.getValue()}));
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
PatternMatchResult WorkGroupSizeConversion::matchAndRewrite(
|
||||
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
|
||||
gpu::BlockDimOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto index = getLaunchConfigIndex(op);
|
||||
if (!index)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
|
||||
auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
|
||||
auto convertedType = typeConverter.convertType(op.getResult().getType());
|
||||
if (!convertedType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||||
op, convertedType, IntegerAttr::get(convertedType, val));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -343,11 +343,11 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
|
|||
return newFuncOp;
|
||||
}
|
||||
|
||||
PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
|
||||
LogicalResult GPUFuncOpConversion::matchAndRewrite(
|
||||
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (!gpu::GPUDialect::isKernel(funcOp))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
|
||||
for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
|
||||
|
@ -358,22 +358,22 @@ PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
|
|||
auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
|
||||
if (!entryPointAttr) {
|
||||
funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
spirv::FuncOp newFuncOp = lowerAsEntryFunction(
|
||||
funcOp, typeConverter, rewriter, entryPointAttr, argABI);
|
||||
if (!newFuncOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
rewriter.getContext()));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleOp with gpu.module.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult GPUModuleConversion::matchAndRewrite(
|
||||
LogicalResult GPUModuleConversion::matchAndRewrite(
|
||||
gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto spvModule = rewriter.create<spirv::ModuleOp>(
|
||||
|
@ -389,21 +389,21 @@ PatternMatchResult GPUModuleConversion::matchAndRewrite(
|
|||
// legalized later.
|
||||
spvModuleRegion.back().erase();
|
||||
rewriter.eraseOp(moduleOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPU return inside kernel functions to SPIR-V return.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
|
||||
LogicalResult GPUReturnOpConversion::matchAndRewrite(
|
||||
gpu::ReturnOp returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (!operands.empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -130,7 +130,7 @@ public:
|
|||
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<RangeOp>(op);
|
||||
|
@ -146,7 +146,7 @@ public:
|
|||
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
|
||||
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -160,14 +160,14 @@ public:
|
|||
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reshapeOp = cast<ReshapeOp>(op);
|
||||
MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
|
||||
|
||||
if (!dstType.hasStaticShape())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
|
@ -175,7 +175,7 @@ public:
|
|||
if (failed(res) || llvm::any_of(strides, [](int64_t val) {
|
||||
return ShapedType::isDynamicStrideOrOffset(val);
|
||||
}))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
ReshapeOpOperandAdaptor adaptor(operands);
|
||||
|
@ -189,7 +189,7 @@ public:
|
|||
for (auto en : llvm::enumerate(strides))
|
||||
desc.setConstantStride(en.index(), en.value());
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -204,7 +204,7 @@ public:
|
|||
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
|
@ -247,7 +247,7 @@ public:
|
|||
|
||||
// Corner case, no sizes or strides: early return the descriptor.
|
||||
if (sliceOp.getShapedType().getRank() == 0)
|
||||
return rewriter.replaceOp(op, {desc}), matchSuccess();
|
||||
return rewriter.replaceOp(op, {desc}), success();
|
||||
|
||||
Value zero = llvm_constant(
|
||||
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||
|
@ -279,7 +279,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -297,7 +297,7 @@ public:
|
|||
: ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Initialize the common boilerplate and alloca at the top of the FuncOp.
|
||||
|
@ -308,7 +308,7 @@ public:
|
|||
auto transposeOp = cast<TransposeOp>(op);
|
||||
// No permutation, early exit.
|
||||
if (transposeOp.permutation().isIdentity())
|
||||
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
|
||||
return rewriter.replaceOp(op, {baseDesc}), success();
|
||||
|
||||
BaseViewConversionHelper desc(
|
||||
typeConverter.convertType(transposeOp.getShapedType()));
|
||||
|
@ -330,7 +330,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -340,11 +340,11 @@ public:
|
|||
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -416,15 +416,15 @@ class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
|
|||
public:
|
||||
using OpRewritePattern<LinalgOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(LinalgOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -434,22 +434,22 @@ template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
|
|||
public:
|
||||
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(CopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(CopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto inputPerm = op.inputPermutation();
|
||||
if (inputPerm.hasValue() && !inputPerm->isIdentity())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
auto outputPerm = op.outputPermutation();
|
||||
if (outputPerm.hasValue() && !outputPerm->isIdentity())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(
|
||||
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -460,12 +460,12 @@ class LinalgOpConversion<IndexedGenericOp>
|
|||
public:
|
||||
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(IndexedGenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(IndexedGenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto libraryCallName =
|
||||
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
|
||||
if (!libraryCallName)
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
|
||||
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
|
||||
// IndexedGenericOp is tiled.
|
||||
|
@ -483,7 +483,7 @@ public:
|
|||
}
|
||||
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
|
||||
ArrayRef<Type>{}, operands);
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -495,8 +495,8 @@ class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
|
|||
public:
|
||||
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(CopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(CopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value in = op.input(), out = op.output();
|
||||
|
||||
// If either inputPerm or outputPerm are non-identities, insert transposes.
|
||||
|
@ -511,10 +511,10 @@ public:
|
|||
|
||||
// If nothing was transposed, fail and let the conversion kick in.
|
||||
if (in == op.input() && out == op.output())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ public:
|
|||
static Optional<linalg::RegionMatcher::BinaryOpKind>
|
||||
matchAsPerformingReduction(linalg::GenericOp genericOp);
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
|
|||
return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
|
||||
}
|
||||
|
||||
PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
|
||||
LogicalResult SingleWorkgroupReduction::matchAndRewrite(
|
||||
linalg::GenericOp genericOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Operation *op = genericOp.getOperation();
|
||||
|
@ -118,19 +118,19 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
|
|||
|
||||
auto binaryOpKind = matchAsPerformingReduction(genericOp);
|
||||
if (!binaryOpKind)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Query the shader interface for local workgroup size to make sure the
|
||||
// invocation configuration fits with the input memref's shape.
|
||||
DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
|
||||
if (!localSize)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
|
||||
[](const APInt &size) { return !size.isOneValue(); }))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// TODO(antiagainst): Query the target environment to make sure the current
|
||||
// workload fits in a local workgroup.
|
||||
|
@ -195,7 +195,7 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
|
|||
spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);
|
||||
|
||||
rewriter.eraseOp(genericOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -98,8 +98,8 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
|
|||
struct ForLowering : public OpRewritePattern<ForOp> {
|
||||
using OpRewritePattern<ForOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ForOp forOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(ForOp forOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
// Create a CFG subgraph for the loop.if operation (including its "then" and
|
||||
|
@ -147,20 +147,20 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
|||
struct IfLowering : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(IfOp ifOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(IfOp ifOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
|
||||
using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
PatternMatchResult
|
||||
ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
|
||||
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = forOp.getLoc();
|
||||
|
||||
// Start by splitting the block containing the 'loop.for' into two parts.
|
||||
|
@ -189,7 +189,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
|
|||
auto step = forOp.step();
|
||||
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
|
||||
if (!stepped)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
SmallVector<Value, 8> loopCarried;
|
||||
loopCarried.push_back(stepped);
|
||||
|
@ -202,7 +202,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
|
|||
Value lowerBound = forOp.lowerBound();
|
||||
Value upperBound = forOp.upperBound();
|
||||
if (!lowerBound || !upperBound)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// The initial values of loop-carried values is obtained from the operands
|
||||
// of the loop operation.
|
||||
|
@ -222,11 +222,11 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
|
|||
// The result of the loop operation is the values of the condition block
|
||||
// arguments except the induction variable on the last iteration.
|
||||
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
PatternMatchResult
|
||||
IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
|
||||
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = ifOp.getLoc();
|
||||
|
||||
// Start by splitting the block containing the 'loop.if' into two parts.
|
||||
|
@ -265,10 +265,10 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
|
|||
|
||||
// Ok, we're done!
|
||||
rewriter.eraseOp(ifOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
Location loc = parallelOp.getLoc();
|
||||
|
@ -344,7 +344,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
|||
|
||||
rewriter.replaceOp(parallelOp, loopResults);
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateLoopToStdConversionPatterns(
|
||||
|
|
|
@ -497,8 +497,8 @@ namespace {
|
|||
struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
|
||||
using OpRewritePattern<ParallelOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct MappingAnnotation {
|
||||
|
@ -742,7 +742,7 @@ static LogicalResult processParallelLoop(ParallelOp parallelOp,
|
|||
/// the actual loop bound. This only works if an static upper bound for the
|
||||
/// dynamic loop bound can be defived, currently via analyzing `affine.min`
|
||||
/// operations.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
// Create a launch operation. We start with bound one for all grid/block
|
||||
|
@ -761,7 +761,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
|||
SmallVector<Operation *, 16> worklist;
|
||||
if (failed(processParallelLoop(parallelOp, launchOp, cloningMap, worklist,
|
||||
launchBounds, rewriter)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Whether we have seen any side-effects. Reset when leaving an inner scope.
|
||||
bool seenSideeffects = false;
|
||||
|
@ -778,13 +778,13 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
|||
// Before entering a nested scope, make sure there have been no
|
||||
// sideeffects until now.
|
||||
if (seenSideeffects)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// A nested loop.parallel needs insertion of code to compute indices.
|
||||
// Insert that now. This will also update the worklist with the loops
|
||||
// body.
|
||||
if (failed(processParallelLoop(nestedParallel, launchOp, cloningMap,
|
||||
worklist, launchBounds, rewriter)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
} else if (op == launchOp.getOperation()) {
|
||||
// Found our sentinel value. We have finished the operations from one
|
||||
// nesting level, pop one level back up.
|
||||
|
@ -802,7 +802,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
|||
clone->getNumRegions() != 0;
|
||||
// If we are no longer in the innermost scope, sideeffects are disallowed.
|
||||
if (seenSideeffects && leftNestingScope)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -812,7 +812,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
|||
launchOp.setOperand(std::get<0>(bound), std::get<1>(bound));
|
||||
|
||||
rewriter.eraseOp(parallelOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
|
||||
|
|
|
@ -946,7 +946,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|||
bool emitCWrappers)
|
||||
: FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
|
@ -962,7 +962,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
|
|||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -976,7 +976,7 @@ private:
|
|||
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
||||
using FuncOpConversionBase::FuncOpConversionBase;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
|
@ -990,7 +990,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|||
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
|
||||
if (newFuncOp.getBody().empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Promote bare pointers from MemRef arguments to a MemRef descriptor struct
|
||||
|
@ -1017,7 +1017,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
|
|||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1109,7 +1109,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// Convert the type of the result to an LLVM type, pass operands as is,
|
||||
// preserve attributes.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numResults = op->getNumResults();
|
||||
|
@ -1119,7 +1119,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
packedType =
|
||||
this->typeConverter.packFunctionResults(op->getResultTypes());
|
||||
if (!packedType)
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
|
||||
|
@ -1127,10 +1127,10 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// If the operation produced 0 or 1 result, return them immediately.
|
||||
if (numResults == 0)
|
||||
return rewriter.eraseOp(op), this->matchSuccess();
|
||||
return rewriter.eraseOp(op), success();
|
||||
if (numResults == 1)
|
||||
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
|
||||
this->matchSuccess();
|
||||
success();
|
||||
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
// Extract individual results from the structure and return them as list.
|
||||
|
@ -1143,7 +1143,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
rewriter.getI64ArrayAttr(i)));
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1207,7 +1207,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// Convert the type of the result to an LLVM type, pass operands as is,
|
||||
// preserve attributes.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ValidateOpCount<SourceOp, OpCount>();
|
||||
|
@ -1221,7 +1221,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
for (Value operand : operands) {
|
||||
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
|
||||
|
@ -1230,7 +1230,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
auto newOp = rewriter.create<TargetOp>(
|
||||
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
|
||||
rewriter.replaceOp(op, newOp.getResult());
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (succeeded(HandleMultidimensionalVectors(
|
||||
|
@ -1240,8 +1240,8 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
operands, op->getAttrs());
|
||||
},
|
||||
rewriter)))
|
||||
return this->matchSuccess();
|
||||
return this->matchFailure();
|
||||
return success();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1381,24 +1381,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
|
||||
useAlloca(useAlloca) {}
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
LogicalResult match(Operation *op) const override {
|
||||
MemRefType type = cast<AllocOp>(op).getType();
|
||||
if (isSupportedMemRefType(type))
|
||||
return matchSuccess();
|
||||
return success();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
if (failed(successStrides))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
|
||||
// is guaranteed when succeeded(successStrides)). Dynamic offset however can
|
||||
// never be alloc'ed.
|
||||
if (offset == MemRefType::getDynamicStrideOrOffset())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -1574,7 +1574,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
using Super = CallOpInterfaceLowering<CallOpType>;
|
||||
using Base = LLVMLegalizationPattern<CallOpType>;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OperandAdaptor<CallOpType> transformed(operands);
|
||||
|
@ -1595,7 +1595,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
if (numResults != 0) {
|
||||
if (!(packedResult =
|
||||
this->typeConverter.packFunctionResults(resultTypes)))
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto promoted = this->typeConverter.promoteMemRefDescriptors(
|
||||
|
@ -1606,7 +1606,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
// If < 2 results, packing did not do anything and we can just return.
|
||||
if (numResults < 2) {
|
||||
rewriter.replaceOp(op, newOp.getResults());
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, it had been converted to an operation producing a structure.
|
||||
|
@ -1624,7 +1624,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1647,11 +1647,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
|
||||
useAlloca(useAlloca) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (useAlloca)
|
||||
return rewriter.eraseOp(op), matchSuccess();
|
||||
return rewriter.eraseOp(op), success();
|
||||
|
||||
assert(operands.size() == 1 && "dealloc takes one operand");
|
||||
OperandAdaptor<DeallocOp> transformed(operands);
|
||||
|
@ -1673,7 +1673,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
memref.allocatedPtr(rewriter, op->getLoc()));
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
bool useAlloca;
|
||||
|
@ -1683,7 +1683,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
||||
using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OperandAdaptor<RsqrtOp> transformed(operands);
|
||||
|
@ -1691,7 +1691,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto resultType = *op->result_type_begin();
|
||||
|
@ -1709,12 +1709,12 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
}
|
||||
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
auto vectorType = resultType.dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
|
||||
if (succeeded(HandleMultidimensionalVectors(
|
||||
op, operands, typeConverter,
|
||||
|
@ -1732,8 +1732,8 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
sqrt);
|
||||
},
|
||||
rewriter)))
|
||||
return this->matchSuccess();
|
||||
return this->matchFailure();
|
||||
return success();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1741,7 +1741,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
||||
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
|
@ -1753,7 +1753,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
|||
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
std::string functionName;
|
||||
if (operandType.isFloatTy())
|
||||
|
@ -1761,7 +1761,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
|||
else if (operandType.isDoubleTy())
|
||||
functionName = "tanh";
|
||||
else
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get a reference to the tanh function, inserting it if necessary.
|
||||
Operation *tanhFunc =
|
||||
|
@ -1783,14 +1783,14 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
|||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
|
||||
op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
|
||||
transformed.operand());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
||||
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
LogicalResult match(Operation *op) const override {
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
Type srcType = memRefCastOp.getOperand().getType();
|
||||
Type dstType = memRefCastOp.getType();
|
||||
|
@ -1801,8 +1801,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
|
||||
return (isSupportedMemRefType(targetType) &&
|
||||
isSupportedMemRefType(sourceType))
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
? success()
|
||||
: failure();
|
||||
}
|
||||
|
||||
// At least one of the operands is unranked type
|
||||
|
@ -1812,8 +1812,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
// Unranked to unranked cast is disallowed
|
||||
return !(srcType.isa<UnrankedMemRefType>() &&
|
||||
dstType.isa<UnrankedMemRefType>())
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
? success()
|
||||
: failure();
|
||||
}
|
||||
|
||||
void rewrite(Operation *op, ArrayRef<Value> operands,
|
||||
|
@ -1886,17 +1886,17 @@ struct DialectCastOpLowering
|
|||
: public LLVMLegalizationPattern<LLVM::DialectCastOp> {
|
||||
using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto castOp = cast<LLVM::DialectCastOp>(op);
|
||||
OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
|
||||
if (transformed.in().getType() !=
|
||||
typeConverter.convertType(castOp.getType())) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOp(op, transformed.in());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1905,7 +1905,7 @@ struct DialectCastOpLowering
|
|||
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
||||
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dimOp = cast<DimOp>(op);
|
||||
|
@ -1922,7 +1922,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
|||
// Use constant for static size.
|
||||
rewriter.replaceOp(
|
||||
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1934,10 +1934,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
|
||||
using Base = LoadStoreOpLowering<Derived>;
|
||||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
LogicalResult match(Operation *op) const override {
|
||||
MemRefType type = cast<Derived>(op).getMemRefType();
|
||||
return isSupportedMemRefType(type) ? this->matchSuccess()
|
||||
: this->matchFailure();
|
||||
return isSupportedMemRefType(type) ? success() : failure();
|
||||
}
|
||||
|
||||
// Given subscript indices and array sizes in row-major order,
|
||||
|
@ -2010,7 +2009,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
||||
using Base::Base;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loadOp = cast<LoadOp>(op);
|
||||
|
@ -2020,7 +2019,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|||
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
|
||||
transformed.indices(), rewriter, getModule());
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2029,7 +2028,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
|
|||
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
||||
using Base::Base;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto type = cast<StoreOp>(op).getMemRefType();
|
||||
|
@ -2039,7 +2038,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
|||
transformed.indices(), rewriter, getModule());
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
|
||||
dataPtr);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2048,7 +2047,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
|
|||
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
||||
using Base::Base;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto prefetchOp = cast<PrefetchOp>(op);
|
||||
|
@ -2072,7 +2071,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
|||
|
||||
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
|
||||
localityHint, isData);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2083,7 +2082,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
|
|||
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
|
||||
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
IndexCastOpOperandAdaptor transformed(operands);
|
||||
|
@ -2104,7 +2103,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
|
|||
else
|
||||
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
|
||||
transformed.in());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2118,7 +2117,7 @@ static LLVMPredType convertCmpPredicate(StdPredType pred) {
|
|||
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
|
||||
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto cmpiOp = cast<CmpIOp>(op);
|
||||
|
@ -2130,14 +2129,14 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
|
|||
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
|
||||
transformed.lhs(), transformed.rhs());
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
|
||||
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto cmpfOp = cast<CmpFOp>(op);
|
||||
|
@ -2149,7 +2148,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
|
|||
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
|
||||
transformed.lhs(), transformed.rhs());
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2189,12 +2188,12 @@ struct OneToOneLLVMTerminatorLowering
|
|||
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
|
||||
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
|
||||
op->getAttrs());
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2207,7 +2206,7 @@ struct OneToOneLLVMTerminatorLowering
|
|||
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
||||
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
unsigned numArguments = op->getNumOperands();
|
||||
|
@ -2216,12 +2215,12 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||
if (numArguments == 0) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||
op, ArrayRef<Type>(), ArrayRef<Value>(), op->getAttrs());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
if (numArguments == 1) {
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||
op, ArrayRef<Type>(), operands.front(), op->getAttrs());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, we need to pack the arguments into an LLVM struct type before
|
||||
|
@ -2237,7 +2236,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, ArrayRef<Type>(), packed,
|
||||
op->getAttrs());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2256,13 +2255,13 @@ struct CondBranchOpLowering
|
|||
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
||||
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto splatOp = cast<SplatOp>(op);
|
||||
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
||||
if (!resultType || resultType.getRank() != 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// First insert it into an undef vector so we can shuffle it.
|
||||
auto vectorType = typeConverter.convertType(splatOp.getType());
|
||||
|
@ -2280,7 +2279,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
|||
// Shuffle the value across the desired number of elements.
|
||||
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
|
||||
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2290,14 +2289,14 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
|||
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
||||
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto splatOp = cast<SplatOp>(op);
|
||||
OperandAdaptor<SplatOp> adaptor(operands);
|
||||
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
||||
if (!resultType || resultType.getRank() == 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// First insert it into an undef vector so we can shuffle it.
|
||||
auto loc = op->getLoc();
|
||||
|
@ -2305,7 +2304,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
|||
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
|
||||
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
|
||||
if (!llvmArrayTy || !llvmVectorTy)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Construct returned value.
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
|
||||
|
@ -2332,7 +2331,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
|||
position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2344,7 +2343,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
|||
struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
||||
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -2376,7 +2375,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
|||
auto targetDescTy = typeConverter.convertType(viewMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!sourceElementTy || !targetDescTy)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Currently, only rank > 0 and full or no operands are supported. Fail to
|
||||
// convert otherwise.
|
||||
|
@ -2385,22 +2384,22 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
|||
(!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
|
||||
(!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
|
||||
(!dynamicStrides.empty() && rank != dynamicStrides.size()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
|
||||
if (failed(successStrides))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Fail to convert if neither a dynamic nor static offset is available.
|
||||
if (dynamicOffsets.empty() &&
|
||||
offset == MemRefType::getDynamicStrideOrOffset())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Create the descriptor.
|
||||
if (!operands.front().getType().isa<LLVM::LLVMType>())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
MemRefDescriptor sourceMemRef(operands.front());
|
||||
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
|
||||
|
||||
|
@ -2460,7 +2459,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {targetMemRef});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2505,7 +2504,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
return createIndexConstant(rewriter, loc, 1);
|
||||
}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -2520,14 +2519,13 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
|
||||
if (!targetDescTy)
|
||||
return op->emitWarning("Target descriptor type not converted to LLVM"),
|
||||
matchFailure();
|
||||
failure();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
|
||||
if (failed(successStrides))
|
||||
return op->emitWarning("cannot cast to non-strided shape"),
|
||||
matchFailure();
|
||||
return op->emitWarning("cannot cast to non-strided shape"), failure();
|
||||
|
||||
// Create the descriptor.
|
||||
MemRefDescriptor sourceMemRef(adaptor.source());
|
||||
|
@ -2560,12 +2558,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
|
||||
// Early exit for 0-D corner case.
|
||||
if (viewMemRefType.getRank() == 0)
|
||||
return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
|
||||
return rewriter.replaceOp(op, {targetMemRef}), success();
|
||||
|
||||
// Fields 4 and 5: Update sizes and strides.
|
||||
if (strides.back() != 1)
|
||||
return op->emitWarning("cannot cast to non-contiguous shape"),
|
||||
matchFailure();
|
||||
return op->emitWarning("cannot cast to non-contiguous shape"), failure();
|
||||
Value stride = nullptr, nextSize = nullptr;
|
||||
// Drop the dynamic stride from the operand list, if present.
|
||||
ArrayRef<Value> sizeOperands(sizeAndOffsetOperands);
|
||||
|
@ -2583,7 +2580,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {targetMemRef});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2591,7 +2588,7 @@ struct AssumeAlignmentOpLowering
|
|||
: public LLVMLegalizationPattern<AssumeAlignmentOp> {
|
||||
using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OperandAdaptor<AssumeAlignmentOp> transformed(operands);
|
||||
|
@ -2622,7 +2619,7 @@ struct AssumeAlignmentOpLowering
|
|||
rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2657,13 +2654,13 @@ namespace {
|
|||
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
||||
using Base::Base;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto atomicOp = cast<AtomicRMWOp>(op);
|
||||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (!maybeKind)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
OperandAdaptor<AtomicRMWOp> adaptor(operands);
|
||||
auto resultType = adaptor.value().getType();
|
||||
auto memRefType = atomicOp.getMemRefType();
|
||||
|
@ -2672,7 +2669,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|||
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
|
||||
op, resultType, *maybeKind, dataPtr, adaptor.value(),
|
||||
LLVM::AtomicOrdering::acq_rel);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2706,13 +2703,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|||
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
||||
using Base::Base;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto atomicOp = cast<AtomicRMWOp>(op);
|
||||
auto maybeKind = matchSimpleAtomicOp(atomicOp);
|
||||
if (maybeKind)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
LLVM::FCmpPredicate predicate;
|
||||
switch (atomicOp.kind()) {
|
||||
|
@ -2723,7 +2720,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|||
predicate = LLVM::FCmpPredicate::olt;
|
||||
break;
|
||||
default:
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
OperandAdaptor<AtomicRMWOp> adaptor(operands);
|
||||
|
@ -2779,7 +2776,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
|
|||
// The 'result' of the atomic_rmw op is the newly loaded value.
|
||||
rewriter.replaceOp(op, {newLoaded});
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -45,7 +45,7 @@ class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -55,7 +55,7 @@ class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -65,7 +65,7 @@ class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -81,14 +81,14 @@ class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultType =
|
||||
this->typeConverter.convertType(operation.getResult().getType());
|
||||
rewriter.template replaceOpWithNewOp<SPIRVOp>(
|
||||
operation, resultType, operands, ArrayRef<NamedAttribute>());
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -100,7 +100,7 @@ class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -111,7 +111,7 @@ class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -121,7 +121,7 @@ public:
|
|||
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -134,7 +134,7 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -145,22 +145,22 @@ public:
|
|||
// ConstantOp with composite type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
|
||||
LogicalResult ConstantCompositeOpConversion::matchAndRewrite(
|
||||
ConstantOp constCompositeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto compositeType =
|
||||
constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!compositeType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto spirvCompositeType = typeConverter.convertType(compositeType);
|
||||
if (!spirvCompositeType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto linearizedElements =
|
||||
constCompositeOp.value().dyn_cast<DenseElementsAttr>();
|
||||
if (!linearizedElements)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// If composite type has rank greater than one, then perform linearization.
|
||||
if (compositeType.getRank() > 1) {
|
||||
|
@ -171,24 +171,24 @@ PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
|
|||
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||||
constCompositeOp, spirvCompositeType, linearizedElements);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp with index type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
||||
LogicalResult ConstantIndexOpConversion::matchAndRewrite(
|
||||
ConstantOp constIndexOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (!constIndexOp.getResult().getType().isa<IndexType>()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
// The attribute has index type which is not directly supported in
|
||||
// SPIR-V. Get the integer value and create a new IntegerAttr.
|
||||
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
|
||||
if (!constAttr) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Use the bitwidth set in the value attribute to decide the result type
|
||||
|
@ -197,7 +197,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
|||
auto constVal = constAttr.getValue();
|
||||
auto constValType = constAttr.getType().dyn_cast<IndexType>();
|
||||
if (!constValType) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
auto spirvConstType =
|
||||
typeConverter.convertType(constIndexOp.getResult().getType());
|
||||
|
@ -205,14 +205,14 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
|||
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
|
||||
spirvConstVal);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CmpFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
CmpFOpOperandAdaptor cmpFOpOperands(operands);
|
||||
|
@ -223,7 +223,7 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|||
rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
|
||||
cmpFOpOperands.lhs(), \
|
||||
cmpFOpOperands.rhs()); \
|
||||
return matchSuccess();
|
||||
return success();
|
||||
|
||||
// Ordered.
|
||||
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
|
||||
|
@ -245,14 +245,14 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|||
default:
|
||||
break;
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CmpIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
CmpIOpOperandAdaptor cmpIOpOperands(operands);
|
||||
|
@ -263,7 +263,7 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
|
||||
cmpIOpOperands.lhs(), \
|
||||
cmpIOpOperands.rhs()); \
|
||||
return matchSuccess();
|
||||
return success();
|
||||
|
||||
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
|
||||
DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
|
||||
|
@ -278,14 +278,14 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
|
||||
#undef DISPATCH
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
LoadOpOperandAdaptor loadOperands(operands);
|
||||
|
@ -293,42 +293,42 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|||
typeConverter, loadOp.memref().getType().cast<MemRefType>(),
|
||||
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
|
||||
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (returnOp.getNumOperands()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
SelectOpOperandAdaptor selectOperands(operands);
|
||||
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
|
||||
selectOperands.true_value(),
|
||||
selectOperands.false_value());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
StoreOpOperandAdaptor storeOperands(operands);
|
||||
|
@ -338,7 +338,7 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
|||
rewriter);
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
|
||||
storeOperands.value());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -26,8 +26,8 @@ class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
|
|||
public:
|
||||
using OpRewritePattern<LoadOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(LoadOp loadOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(LoadOp loadOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Merges subview operation with store operation.
|
||||
|
@ -35,8 +35,8 @@ class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
|
|||
public:
|
||||
using OpRewritePattern<StoreOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(StoreOp storeOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(StoreOp storeOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -107,43 +107,43 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
|
|||
// Folding SubViewOp and LoadOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
|
||||
if (!subViewOp) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
SmallVector<Value, 4> sourceIndices;
|
||||
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
|
||||
loadOp.indices(), sourceIndices)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
|
||||
sourceIndices);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Folding SubViewOp and StoreOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
|
||||
if (!subViewOp) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
SmallVector<Value, 4> sourceIndices;
|
||||
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
|
||||
storeOp.indices(), sourceIndices)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
|
||||
subViewOp.source(), sourceIndices);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -133,13 +133,13 @@ public:
|
|||
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto broadcastOp = cast<vector::BroadcastOp>(op);
|
||||
VectorType dstVectorType = broadcastOp.getVectorType();
|
||||
if (typeConverter.convertType(dstVectorType) == nullptr)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Rewrite when the full vector type can be lowered (which
|
||||
// implies all 'reduced' types can be lowered too).
|
||||
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
|
||||
|
@ -149,7 +149,7 @@ public:
|
|||
op, expandRanks(adaptor.source(), // source value to be expanded
|
||||
op->getLoc(), // location of original broadcast
|
||||
srcVectorType, dstVectorType, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -284,7 +284,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto matmulOp = cast<vector::MatmulOp>(op);
|
||||
|
@ -293,7 +293,7 @@ public:
|
|||
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
|
||||
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
|
||||
matmulOp.rhs_columns());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -304,7 +304,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto reductionOp = cast<vector::ReductionOp>(op);
|
||||
|
@ -335,8 +335,8 @@ public:
|
|||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
|
||||
op, llvmType, operands[0]);
|
||||
else
|
||||
return matchFailure();
|
||||
return matchSuccess();
|
||||
return failure();
|
||||
return success();
|
||||
|
||||
} else if (eltType.isF32() || eltType.isF64()) {
|
||||
// Floating-point reductions: add/mul/min/max
|
||||
|
@ -364,10 +364,10 @@ public:
|
|||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
|
||||
op, llvmType, operands[0]);
|
||||
else
|
||||
return matchFailure();
|
||||
return matchSuccess();
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -378,7 +378,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -392,7 +392,7 @@ public:
|
|||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get rank and dimension sizes.
|
||||
int64_t rank = vectorType.getRank();
|
||||
|
@ -406,7 +406,7 @@ public:
|
|||
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
|
||||
rewriter.replaceOp(op, shuffle);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// For all other cases, insert the individual values individually.
|
||||
|
@ -425,7 +425,7 @@ public:
|
|||
llvmType, rank, insPos++);
|
||||
}
|
||||
rewriter.replaceOp(op, insert);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -436,7 +436,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
|
@ -446,11 +446,11 @@ public:
|
|||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, llvmType, adaptor.vector(), adaptor.position());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -461,7 +461,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -474,14 +474,14 @@ public:
|
|||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (resultType.isa<VectorType>()) {
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
|
@ -505,7 +505,7 @@ public:
|
|||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(op, extracted);
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -530,17 +530,17 @@ public:
|
|||
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FMAOpOperandAdaptor(operands);
|
||||
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
|
||||
VectorType vType = fmaOp.getVectorType();
|
||||
if (vType.getRank() != 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
|
||||
adaptor.acc());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -551,7 +551,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
||||
|
@ -561,11 +561,11 @@ public:
|
|||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -576,7 +576,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -589,7 +589,7 @@ public:
|
|||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (sourceType.isa<VectorType>()) {
|
||||
|
@ -597,7 +597,7 @@ public:
|
|||
loc, llvmResultType, adaptor.dest(), adaptor.source(),
|
||||
positionArrayAttr);
|
||||
rewriter.replaceOp(op, inserted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
|
@ -632,7 +632,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, inserted);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -661,11 +661,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
|
|||
public:
|
||||
using OpRewritePattern<FMAOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto vType = op.getVectorType();
|
||||
if (vType.getRank() < 2)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = vType.getElementType();
|
||||
|
@ -680,7 +680,7 @@ public:
|
|||
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -704,19 +704,19 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
|
|||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff == 0)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t rankRest = dstType.getRank() - rankDiff;
|
||||
// Extract / insert the subvector of matching rank and InsertStridedSlice
|
||||
|
@ -735,7 +735,7 @@ public:
|
|||
op, stridedSliceInnerOp.getResult(), op.dest(),
|
||||
getI64SubArray(op.offsets(), /*dropFront=*/0,
|
||||
/*dropFront=*/rankRest));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -753,22 +753,22 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
|
|||
public:
|
||||
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcType = op.getSourceVectorType();
|
||||
auto dstType = op.getDestVectorType();
|
||||
|
||||
if (op.offsets().getValue().empty())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t rankDiff = dstType.getRank() - srcType.getRank();
|
||||
assert(rankDiff >= 0);
|
||||
if (rankDiff != 0)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
if (srcType == dstType) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t offset =
|
||||
|
@ -813,7 +813,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, res);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -824,7 +824,7 @@ public:
|
|||
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
|
@ -837,18 +837,18 @@ public:
|
|||
// Only static shape casts supported atm.
|
||||
if (!sourceMemRefType.hasStaticShape() ||
|
||||
!targetMemRefType.hasStaticShape())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto llvmSourceDescriptorTy =
|
||||
operands[0].getType().dyn_cast<LLVM::LLVMType>();
|
||||
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
MemRefDescriptor sourceMemRef(operands[0]);
|
||||
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
|
@ -866,7 +866,7 @@ public:
|
|||
}
|
||||
// Only contiguous source tensors supported atm.
|
||||
if (failed(successStrides) || !isContiguous)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
|
||||
|
||||
|
@ -901,7 +901,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, {desc});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -924,7 +924,7 @@ public:
|
|||
//
|
||||
// TODO(ajcbik): rely solely on libc in future? something else?
|
||||
//
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto printOp = cast<vector::PrintOp>(op);
|
||||
|
@ -932,7 +932,7 @@ public:
|
|||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (typeConverter.convertType(printType) == nullptr)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support (currently just Float/Double).
|
||||
VectorType vectorType = printType.dyn_cast<VectorType>();
|
||||
|
@ -948,13 +948,13 @@ public:
|
|||
else if (eltType.isF64())
|
||||
printer = getPrintDouble(op);
|
||||
else
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
|
||||
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -1047,8 +1047,8 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
|
|||
public:
|
||||
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(StridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(StridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto dstType = op.getResult().getType().cast<VectorType>();
|
||||
|
||||
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
|
||||
|
@ -1089,7 +1089,7 @@ public:
|
|||
res = insertOne(rewriter, loc, extracted, res, idx);
|
||||
}
|
||||
rewriter.replaceOp(op, {res});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -198,8 +198,8 @@ struct VectorTransferRewriter : public RewritePattern {
|
|||
}
|
||||
|
||||
/// Performs the rewrite.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Lowers TransferReadOp into a combination of:
|
||||
|
@ -246,7 +246,7 @@ struct VectorTransferRewriter : public RewritePattern {
|
|||
|
||||
/// Performs the rewrite.
|
||||
template <>
|
||||
PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
||||
LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
using namespace mlir::edsc::op;
|
||||
|
||||
|
@ -282,7 +282,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
|
||||
// 3. Propagate.
|
||||
rewriter.replaceOp(op, vectorValue.getValue());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Lowers TransferWriteOp into a combination of:
|
||||
|
@ -304,7 +304,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
/// TODO(ntv): implement alternatives to clipping.
|
||||
/// TODO(ntv): support non-data-parallel operations.
|
||||
template <>
|
||||
PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
||||
LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
using namespace edsc::op;
|
||||
|
||||
|
@ -340,7 +340,7 @@ PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
|||
(std_dealloc(tmp)); // vexing parse...
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -727,8 +727,8 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
|
|||
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
|
||||
AffineMap map, ArrayRef<Value> mapOperands) const;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineOpTy affineOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
|
||||
std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
|
||||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
|
||||
|
@ -743,10 +743,10 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
|
|||
composeAffineMapAndOperands(&map, &resultOperands);
|
||||
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
|
||||
resultOperands.begin()))
|
||||
return this->matchFailure();
|
||||
return failure();
|
||||
|
||||
replaceAffineOp(rewriter, affineOp, map, resultOperands);
|
||||
return this->matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1405,13 +1405,13 @@ namespace {
|
|||
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
|
||||
using OpRewritePattern<AffineForOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineForOp forOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AffineForOp forOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the body only contains a terminator.
|
||||
if (!has_single_element(*forOp.getBody()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.eraseOp(forOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
|
@ -111,8 +111,8 @@ namespace {
|
|||
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
|
||||
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Type inputType = op.arg().getType();
|
||||
Type outputType = op.getResult().getType();
|
||||
|
||||
|
@ -121,16 +121,16 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
|
|||
Type expressedOutputType = inputElementType.castToExpressedType(inputType);
|
||||
if (expressedOutputType != outputType) {
|
||||
// Not a valid uniform cast.
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
|
||||
if (!dequantizedValue) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, dequantizedValue);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -313,40 +313,40 @@ namespace {
|
|||
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
|
||||
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(RealAddEwOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(RealAddEwOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
|
||||
op.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Try all of the permutations we support.
|
||||
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
|
||||
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(RealMulEwOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(RealMulEwOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
|
||||
op.clamp_max());
|
||||
if (!info.isValid()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Try all of the permutations we support.
|
||||
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -380,8 +380,8 @@ struct GpuAllReduceConversion : public RewritePattern {
|
|||
explicit GpuAllReduceConversion(MLIRContext *context)
|
||||
: RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto funcOp = cast<gpu::GPUFuncOp>(op);
|
||||
auto callback = [&](gpu::AllReduceOp reduceOp) {
|
||||
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
|
||||
|
@ -391,7 +391,7 @@ struct GpuAllReduceConversion : public RewritePattern {
|
|||
};
|
||||
while (funcOp.walk(callback).wasInterrupted()) {
|
||||
}
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -534,10 +534,10 @@ namespace {
|
|||
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(GenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(GenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.hasTensorSemantics())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Find the first operand that is defined by another generic op on tensors.
|
||||
for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
|
||||
|
@ -551,9 +551,9 @@ struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
|
|||
if (!fusedOp)
|
||||
continue;
|
||||
rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -531,13 +531,13 @@ public:
|
|||
explicit LinalgRewritePattern(MLIRContext *context)
|
||||
: RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
|
||||
if (failed(Impl::doit(op, rewriter)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -595,26 +595,26 @@ struct FoldAffineOp : public RewritePattern {
|
|||
FoldAffineOp(MLIRContext *context)
|
||||
: RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
|
||||
auto map = affineApplyOp.getAffineMap();
|
||||
if (map.getNumResults() != 1 || map.getNumInputs() > 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
AffineExpr expr = map.getResult(0);
|
||||
if (map.getNumInputs() == 0) {
|
||||
if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
|
||||
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -30,8 +30,8 @@ public:
|
|||
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
|
||||
using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
@ -39,14 +39,14 @@ struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
|
|||
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
|
||||
/// quantized and the operand type is quantizable.
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
||||
PatternRewriter &rewriter) const {
|
||||
Attribute value;
|
||||
|
||||
// Is the operand a constant?
|
||||
if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Does the qbarrier convert to a quantized type. This will not be true
|
||||
|
@ -56,10 +56,10 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
QuantizedType quantizedElementType =
|
||||
QuantizedType::getQuantizedElementType(qbarrierResultType);
|
||||
if (!quantizedElementType) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Is the operand type compatible with the expressed type of the quantized
|
||||
|
@ -67,20 +67,20 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
// from and to a quantized type).
|
||||
if (!quantizedElementType.isCompatibleExpressedType(
|
||||
qbarrier.arg().getType())) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Is the constant value a type expressed in a way that we support?
|
||||
if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
|
||||
!value.isa<SparseElementsAttr>()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
Type newConstValueType;
|
||||
auto newConstValue =
|
||||
quantizeAttr(value, quantizedElementType, newConstValueType);
|
||||
if (!newConstValue) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// When creating the new const op, use a fused location that combines the
|
||||
|
@ -92,7 +92,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
|
||||
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
|
||||
newConstOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void ConvertConstPass::runOnFunction() {
|
||||
|
|
|
@ -35,16 +35,16 @@ public:
|
|||
FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
|
||||
: OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(FakeQuantOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(FakeQuantOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO: If this pattern comes up more frequently, consider adding core
|
||||
// support for failable rewrites.
|
||||
if (failableRewrite(op, rewriter)) {
|
||||
*hadFailure = true;
|
||||
return Pattern::matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
return Pattern::matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -88,13 +88,13 @@ struct CombineChainedAccessChain
|
|||
: public OpRewritePattern<spirv::AccessChainOp> {
|
||||
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||
accessChainOp.base_ptr().getDefiningOp());
|
||||
|
||||
if (!parentAccessChainOp) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Combine indices.
|
||||
|
@ -105,7 +105,7 @@ struct CombineChainedAccessChain
|
|||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -291,24 +291,24 @@ struct ConvertSelectionOpToSelect
|
|||
: public OpRewritePattern<spirv::SelectionOp> {
|
||||
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto *op = selectionOp.getOperation();
|
||||
auto &body = op->getRegion(0);
|
||||
// Verifier allows an empty region for `spv.selection`.
|
||||
if (body.empty()) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that region consists of 4 blocks:
|
||||
// header block, `true` block, `false` block and merge block.
|
||||
if (std::distance(body.begin(), body.end()) != 4) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto *headerBlock = selectionOp.getHeaderBlock();
|
||||
if (!onlyContainsBranchConditionalOp(headerBlock)) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto brConditionalOp =
|
||||
|
@ -319,7 +319,7 @@ struct ConvertSelectionOpToSelect
|
|||
auto *mergeBlock = selectionOp.getMergeBlock();
|
||||
|
||||
if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto trueValue = getSrcValue(trueBlock);
|
||||
auto falseValue = getSrcValue(falseBlock);
|
||||
|
@ -335,7 +335,7 @@ struct ConvertSelectionOpToSelect
|
|||
|
||||
// `spv.selection` is not needed anymore.
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -345,9 +345,8 @@ private:
|
|||
// 2. Each `spv.Store` uses the same pointer and the same memory attributes.
|
||||
// 3. A control flow goes into the given merge block from the given
|
||||
// conditional blocks.
|
||||
PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
|
||||
Block *falseBlock,
|
||||
Block *mergeBlock) const;
|
||||
LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
|
||||
Block *mergeBlock) const;
|
||||
|
||||
bool onlyContainsBranchConditionalOp(Block *block) const {
|
||||
return std::next(block->begin()) == block->end() &&
|
||||
|
@ -382,12 +381,12 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
||||
LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
||||
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
|
||||
// Each block must consists of 2 operations.
|
||||
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
|
||||
(std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
|
||||
|
@ -399,7 +398,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|||
|
||||
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
|
||||
!falseBrBranchOp) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that each `spv.Store` uses the same pointer, memory access
|
||||
|
@ -407,15 +406,15 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|||
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
||||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
|
||||
!isValidType(trueBrStoreOp.value().getType())) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
|
||||
(falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
} // end anonymous namespace
|
||||
|
||||
|
|
|
@ -177,25 +177,25 @@ class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
|
|||
public:
|
||||
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto fnType = funcOp.getType();
|
||||
// TODO(antiagainst): support converting functions with one result.
|
||||
if (fnType.getNumResults())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||
for (auto argType : enumerate(funcOp.getType().getInputs())) {
|
||||
auto convertedType = typeConverter.convertType(argType.value());
|
||||
if (!convertedType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
signatureConverter.addInputs(argType.index(), convertedType);
|
||||
}
|
||||
|
||||
|
@ -216,7 +216,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
|||
newFuncOp.end());
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||
rewriter.eraseOp(funcOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::populateBuiltinFuncToSPIRVPatterns(
|
||||
|
|
|
@ -27,8 +27,8 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
|
|||
public:
|
||||
using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
spirv::StructType::LayoutInfo structSize = 0;
|
||||
VulkanLayoutUtils::Size structAlignment = 1;
|
||||
SmallVector<NamedAttribute, 4> globalVarAttrs;
|
||||
|
@ -50,7 +50,7 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
|
||||
op, TypeAttr::get(decoratedType), globalVarAttrs);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -59,15 +59,15 @@ class SPIRVAddressOfOpLayoutInfoDecoration
|
|||
public:
|
||||
using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto spirvModule = op.getParentOfType<spirv::ModuleOp>();
|
||||
auto varName = op.variable();
|
||||
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
||||
op, varOp.type(), rewriter.getSymbolRefAttr(varName));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -138,7 +138,7 @@ namespace {
|
|||
class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
@ -151,13 +151,13 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
|
||||
LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
|
||||
spirv::FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
|
||||
spirv::getEntryPointABIAttrName())) {
|
||||
// TODO(ravishankarm) : Non-entry point functions are not handled.
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
TypeConverter::SignatureConversion signatureConverter(
|
||||
funcOp.getType().getNumInputs());
|
||||
|
@ -171,12 +171,12 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
|
|||
// to pass around scalar/vector values and return a scalar/vector. For now
|
||||
// non-entry point functions are not handled in this ABI lowering and will
|
||||
// produce an error.
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
auto var =
|
||||
createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
|
||||
if (!var) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
|
||||
|
@ -207,7 +207,7 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
|
|||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
void LowerABIAttributesPass::runOnOperation() {
|
||||
|
|
|
@ -313,14 +313,14 @@ namespace {
|
|||
struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
||||
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check to see if any dimensions operands are constants. If so, we can
|
||||
// substitute and drop them.
|
||||
if (llvm::none_of(alloc.getOperands(), [](Value operand) {
|
||||
return matchPattern(operand, m_ConstantIndex());
|
||||
}))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto memrefType = alloc.getType();
|
||||
|
||||
|
@ -364,7 +364,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
alloc.getType());
|
||||
|
||||
rewriter.replaceOp(alloc, {resultCast});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -373,13 +373,13 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
|
||||
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AllocOp alloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (alloc.use_empty()) {
|
||||
rewriter.eraseOp(alloc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -461,18 +461,18 @@ namespace {
|
|||
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
|
||||
using OpRewritePattern<BranchOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(BranchOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(BranchOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the successor block has a single predecessor.
|
||||
Block *succ = op.getDest();
|
||||
Block *opParent = op.getOperation()->getBlock();
|
||||
if (succ == opParent || !has_single_element(succ->getPredecessors()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Merge the successor into the current block and erase the branch.
|
||||
rewriter.mergeBlocks(succ, opParent, op.getOperands());
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -545,18 +545,18 @@ struct SimplifyIndirectCallWithKnownCallee
|
|||
: public OpRewritePattern<CallIndirectOp> {
|
||||
using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the callee is a constant callee.
|
||||
SymbolRefAttr calledFn;
|
||||
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Replace with a direct call.
|
||||
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
|
||||
indirectCall.getResultTypes(),
|
||||
indirectCall.getArgOperands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -733,20 +733,20 @@ namespace {
|
|||
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
||||
// True branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||
condbr.getTrueOperands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||
// False branch taken.
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||
condbr.getFalseOperands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -958,21 +958,21 @@ namespace {
|
|||
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
||||
using OpRewritePattern<DeallocOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(DeallocOp dealloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(DeallocOp dealloc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check that the memref operand's defining operation is an AllocOp.
|
||||
Value memref = dealloc.memref();
|
||||
if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Check that all of the uses of the AllocOp are other DeallocOps.
|
||||
for (auto *user : memref.getUsers())
|
||||
if (!isa<DeallocOp>(user))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Erase the dealloc operation.
|
||||
rewriter.eraseOp(dealloc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -2003,8 +2003,8 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
|
|||
public:
|
||||
using OpRewritePattern<SubViewOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
MemRefType subViewType = subViewOp.getType();
|
||||
// Follow all or nothing approach for shapes for now. If all the operands
|
||||
// for sizes are constants then fold it into the type of the result memref.
|
||||
|
@ -2012,7 +2012,7 @@ public:
|
|||
llvm::any_of(subViewOp.sizes(), [](Value operand) {
|
||||
return !matchPattern(operand, m_ConstantIndex());
|
||||
})) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
|
||||
for (auto size : llvm::enumerate(subViewOp.sizes())) {
|
||||
|
@ -2028,7 +2028,7 @@ public:
|
|||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2037,10 +2037,10 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
|
|||
public:
|
||||
using OpRewritePattern<SubViewOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (subViewOp.getNumStrides() == 0) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
// Follow all or nothing approach for strides for now. If all the operands
|
||||
// for strides are constants then fold it into the strides of the result
|
||||
|
@ -2056,7 +2056,7 @@ public:
|
|||
llvm::any_of(subViewOp.strides(), [](Value stride) {
|
||||
return !matchPattern(stride, m_ConstantIndex());
|
||||
})) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
|
||||
|
@ -2077,7 +2077,7 @@ public:
|
|||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2086,10 +2086,10 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
|
|||
public:
|
||||
using OpRewritePattern<SubViewOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(SubViewOp subViewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (subViewOp.getNumOffsets() == 0) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
// Follow all or nothing approach for offsets for now. If all the operands
|
||||
// for offsets are constants then fold it into the offset of the result
|
||||
|
@ -2106,7 +2106,7 @@ public:
|
|||
llvm::any_of(subViewOp.offsets(), [](Value stride) {
|
||||
return !matchPattern(stride, m_ConstantIndex());
|
||||
})) {
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto staticOffset = baseOffset;
|
||||
|
@ -2128,7 +2128,7 @@ public:
|
|||
// Insert a memref_cast for compatibility of the uses of the op.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
|
||||
subViewOp.getType());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2347,18 +2347,18 @@ namespace {
|
|||
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
||||
using OpRewritePattern<ViewOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ViewOp viewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(ViewOp viewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if none of the operands are constants.
|
||||
if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
|
||||
return matchPattern(operand, m_ConstantIndex());
|
||||
}))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get result memref type.
|
||||
auto memrefType = viewOp.getType();
|
||||
if (memrefType.getAffineMaps().size() > 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
auto map = memrefType.getAffineMaps().empty()
|
||||
? AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
|
||||
rewriter.getContext())
|
||||
|
@ -2368,7 +2368,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
int64_t oldOffset;
|
||||
SmallVector<int64_t, 4> oldStrides;
|
||||
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
SmallVector<Value, 4> newOperands;
|
||||
|
||||
|
@ -2444,27 +2444,27 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
// Insert a cast so we have the same type as the old memref type.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
|
||||
viewOp.getType());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
|
||||
using OpRewritePattern<ViewOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ViewOp viewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(ViewOp viewOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value memrefOperand = viewOp.getOperand(0);
|
||||
MemRefCastOp memrefCastOp =
|
||||
dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
|
||||
if (!memrefCastOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
Value allocOperand = memrefCastOp.getOperand();
|
||||
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
|
||||
if (!allocOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
|
||||
viewOp.operands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1145,18 +1145,18 @@ class StridedSliceConstantMaskFolder final
|
|||
public:
|
||||
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(StridedSliceOp stridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
|
||||
auto defOp = stridedSliceOp.vector().getDefiningOp();
|
||||
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
|
||||
if (!constantMaskOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Return if 'stridedSliceOp' has non-unit strides.
|
||||
if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
|
||||
return attr.cast<IntegerAttr>().getInt() != 1;
|
||||
}))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Gather constant mask dimension sizes.
|
||||
SmallVector<int64_t, 4> maskDimSizes;
|
||||
populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
|
||||
|
@ -1187,7 +1187,7 @@ public:
|
|||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
stridedSliceOp, stridedSliceOp.getResult().getType(),
|
||||
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1619,14 +1619,14 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
|
|||
public:
|
||||
using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if any of 'createMaskOp' operands are not defined by a constant.
|
||||
auto is_not_def_by_constant = [](Value operand) {
|
||||
return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
|
||||
};
|
||||
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Gather constant mask dimension sizes.
|
||||
SmallVector<int64_t, 4> maskDimSizes;
|
||||
for (auto operand : createMaskOp.operands()) {
|
||||
|
@ -1637,7 +1637,7 @@ public:
|
|||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
createMaskOp, createMaskOp.getResult().getType(),
|
||||
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -545,18 +545,18 @@ namespace {
|
|||
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
||||
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
|
||||
// permutation maps. Repurpose code from MaterializeVectors transformation.
|
||||
if (!isIdentitySuffix(xferReadOp.permutation_map()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
|
||||
Value xferReadResult = xferReadOp.getResult();
|
||||
auto extractSlicesOp =
|
||||
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
|
||||
if (!xferReadResult.hasOneUse() || !extractSlicesOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
|
||||
auto sourceVectorType = extractSlicesOp.getSourceVectorType();
|
||||
|
@ -593,7 +593,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
|||
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
|
||||
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
|
||||
extractSlicesOp.strides());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -601,23 +601,23 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
|||
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
|
||||
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
|
||||
// permutation maps. Repurpose code from MaterializeVectors transformation.
|
||||
if (!isIdentitySuffix(xferWriteOp.permutation_map()))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
// Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
|
||||
auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
|
||||
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
|
||||
if (!insertSlicesOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get TupleOp operand of 'insertSlicesOp'.
|
||||
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
|
||||
insertSlicesOp.vectors().getDefiningOp());
|
||||
if (!tupleOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
|
||||
auto sourceTupleType = insertSlicesOp.getSourceTupleType();
|
||||
|
@ -644,7 +644,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
|
||||
// Erase old 'xferWriteOp'.
|
||||
rewriter.eraseOp(xferWriteOp);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -653,15 +653,15 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
|
||||
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if 'shapeCastOp' has tuple source/result type.
|
||||
auto sourceTupleType =
|
||||
shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
|
||||
auto resultTupleType =
|
||||
shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
|
||||
if (!sourceTupleType || !resultTupleType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
assert(sourceTupleType.size() == resultTupleType.size());
|
||||
|
||||
// Create single-vector ShapeCastOp for each source tuple element.
|
||||
|
@ -679,7 +679,7 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
|
|||
// Replace 'shapeCastOp' with tuple of 'resultElements'.
|
||||
rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
|
||||
resultElements);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -702,21 +702,21 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
|
|||
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
|
||||
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Check if 'shapeCastOp' has vector source/result type.
|
||||
auto sourceVectorType =
|
||||
shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
|
||||
auto resultVectorType =
|
||||
shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
|
||||
if (!sourceVectorType || !resultVectorType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Check if shape cast op source operand is also a shape cast op.
|
||||
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
|
||||
shapeCastOp.source().getDefiningOp());
|
||||
if (!sourceShapeCastOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
auto operandSourceVectorType =
|
||||
sourceShapeCastOp.source().getType().cast<VectorType>();
|
||||
auto operandResultVectorType =
|
||||
|
@ -725,10 +725,10 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
|
|||
// Check if shape cast operations invert each other.
|
||||
if (operandSourceVectorType != resultVectorType ||
|
||||
operandResultVectorType != sourceVectorType)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -738,30 +738,30 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
|
|||
struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
|
||||
using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
|
||||
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
|
||||
tupleGetOp.vectors().getDefiningOp());
|
||||
if (!extractSlicesOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
|
||||
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
|
||||
extractSlicesOp.vector().getDefiningOp());
|
||||
if (!insertSlicesOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
|
||||
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
|
||||
insertSlicesOp.vectors().getDefiningOp());
|
||||
if (!tupleOp)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
|
||||
Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
|
||||
rewriter.replaceOp(tupleGetOp, tupleValue);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -778,8 +778,8 @@ class ExtractSlicesOpLowering
|
|||
public:
|
||||
using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType vectorType = op.getSourceVectorType();
|
||||
|
@ -806,7 +806,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -825,8 +825,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
|
|||
public:
|
||||
using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType vectorType = op.getResultVectorType();
|
||||
|
@ -860,7 +860,7 @@ public:
|
|||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -881,8 +881,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
|
|||
public:
|
||||
using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::OuterProductOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::OuterProductOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType rhsType = op.getOperandVectorTypeRHS();
|
||||
|
@ -907,7 +907,7 @@ public:
|
|||
result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -934,11 +934,11 @@ public:
|
|||
: OpRewritePattern<vector::ContractionOp>(context),
|
||||
vectorTransformsOptions(vectorTransformsOptions) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(ajcbik): implement masks
|
||||
if (llvm::size(op.masks()) != 0)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
|
||||
// a new pattern.
|
||||
|
@ -977,7 +977,7 @@ public:
|
|||
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
|
||||
else
|
||||
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -987,7 +987,7 @@ public:
|
|||
int64_t lhsIndex = batchDimMap[0].first;
|
||||
int64_t rhsIndex = batchDimMap[0].second;
|
||||
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Collect contracting dimensions.
|
||||
|
@ -1007,7 +1007,7 @@ public:
|
|||
if (lhsContractingDimSet.count(lhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1018,17 +1018,17 @@ public:
|
|||
if (rhsContractingDimSet.count(rhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Lower the first remaining reduction dimension.
|
||||
if (!contractingDimMap.empty()) {
|
||||
rewriter.replaceOp(op, lowerReduction(op, rewriter));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -1275,12 +1275,12 @@ class ShapeCastOp2DDownCastRewritePattern
|
|||
public:
|
||||
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto sourceVectorType = op.getSourceVectorType();
|
||||
auto resultVectorType = op.getResultVectorType();
|
||||
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = sourceVectorType.getElementType();
|
||||
|
@ -1295,7 +1295,7 @@ public:
|
|||
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1309,12 +1309,12 @@ class ShapeCastOp2DUpCastRewritePattern
|
|||
public:
|
||||
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto sourceVectorType = op.getSourceVectorType();
|
||||
auto resultVectorType = op.getResultVectorType();
|
||||
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = sourceVectorType.getElementType();
|
||||
|
@ -1330,7 +1330,7 @@ public:
|
|||
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
|
|||
"rewrite functions!");
|
||||
}
|
||||
|
||||
PatternMatchResult RewritePattern::match(Operation *op) const {
|
||||
LogicalResult RewritePattern::match(Operation *op) const {
|
||||
llvm_unreachable("need to implement either match or matchAndRewrite!");
|
||||
}
|
||||
|
||||
|
|
|
@ -35,13 +35,13 @@ public:
|
|||
RemoveIdentityOpRewrite(MLIRContext *context)
|
||||
: RewritePattern(OpTy::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
assert(op->getNumOperands() == 1);
|
||||
assert(op->getNumResults() == 1);
|
||||
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1010,7 +1010,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Attempt to match and rewrite the IR root at the specified operation.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ConversionPattern::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
SmallVector<Value, 4> operands;
|
||||
|
@ -1705,7 +1705,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
|||
: OpConversionPattern(ctx), converter(converter) {}
|
||||
|
||||
/// Hook for derived classes to implement combined matching and rewriting.
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
FunctionType type = funcOp.getType();
|
||||
|
@ -1714,12 +1714,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
|||
TypeConverter::SignatureConversion result(type.getNumInputs());
|
||||
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
|
||||
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Convert the original function results.
|
||||
SmallVector<Type, 1> convertedResults;
|
||||
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Update the function signature in-place.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
|
@ -1727,7 +1727,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
|||
convertedResults, funcOp.getContext()));
|
||||
rewriter.applySignatureConversion(&funcOp.getBody(), result);
|
||||
});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// The type converter to use when rewriting the signature.
|
||||
|
|
|
@ -94,32 +94,32 @@ struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
|
|||
|
||||
struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
|
||||
ConvertToAtomCmpExchangeWeak(MLIRContext *context);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvertToBitReverse : public RewritePattern {
|
||||
ConvertToBitReverse(MLIRContext *context);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvertToGroupNonUniformBallot : public RewritePattern {
|
||||
ConvertToGroupNonUniformBallot(MLIRContext *context);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvertToModule : public RewritePattern {
|
||||
ConvertToModule(MLIRContext *context);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvertToSubgroupBallot : public RewritePattern {
|
||||
ConvertToSubgroupBallot(MLIRContext *context);
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -145,7 +145,7 @@ ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
|
|||
: RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
|
||||
{"spv.AtomicCompareExchangeWeak"}, 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Value ptr = op->getOperand(0);
|
||||
|
@ -159,21 +159,21 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
|
|||
spirv::MemorySemantics::AcquireRelease |
|
||||
spirv::MemorySemantics::AtomicCounterMemory,
|
||||
spirv::MemorySemantics::Acquire, value, comparator);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
|
||||
: RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
|
||||
context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ConvertToBitReverse::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Value predicate = op->getOperand(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
|
||||
op, op->getResult(0).getType(), predicate);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
|
||||
|
@ -181,39 +181,39 @@ ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
|
|||
: RewritePattern("test.convert_to_group_non_uniform_ballot_op",
|
||||
{"spv.GroupNonUniformBallot"}, 1, context) {}
|
||||
|
||||
PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
|
||||
LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter) const {
|
||||
Value predicate = op->getOperand(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
|
||||
op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
ConvertToModule::ConvertToModule(MLIRContext *context)
|
||||
: RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ConvertToModule::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
|
||||
op, spirv::AddressingModel::PhysicalStorageBuffer64,
|
||||
spirv::MemoryModel::Vulkan);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
|
||||
: RewritePattern("test.convert_to_subgroup_ballot_op",
|
||||
{"spv.SubgroupBallotKHR"}, 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Value predicate = op->getOperand(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
|
||||
op, op->getResult(0).getType(), predicate);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -283,10 +283,10 @@ struct TestRemoveOpWithInnerOps
|
|||
: public OpRewritePattern<TestOpWithRegionPattern> {
|
||||
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
|
@ -141,7 +141,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
|
|||
TestRegionRewriteBlockMovement(MLIRContext *ctx)
|
||||
: ConversionPattern("test.region", 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Inline this region into the parent region.
|
||||
|
@ -155,7 +155,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
|
|||
|
||||
// Drop this operation.
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
/// This pattern is a simple pattern that generates a region containing an
|
||||
|
@ -164,8 +164,8 @@ struct TestRegionRewriteUndo : public RewritePattern {
|
|||
TestRegionRewriteUndo(MLIRContext *ctx)
|
||||
: RewritePattern("test.region_builder", 1, ctx) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
// Create the region operation with an entry block containing arguments.
|
||||
OperationState newRegion(op->getLoc(), "test.region");
|
||||
newRegion.addRegion();
|
||||
|
@ -179,7 +179,7 @@ struct TestRegionRewriteUndo : public RewritePattern {
|
|||
|
||||
// Drop this operation.
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -191,7 +191,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
|
|||
TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
|
||||
: ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
|
||||
}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Region ®ion = op->getRegion(0);
|
||||
|
@ -202,12 +202,12 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
|
|||
for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
|
||||
if (failed(converter.convertSignatureArg(
|
||||
i, entry->getArgument(i).getType(), result)))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Convert the region signature and just drop the operation.
|
||||
rewriter.applySignatureConversion(®ion, result);
|
||||
rewriter.eraseOp(op);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// The type converter to use when rewriting the signature.
|
||||
|
@ -217,35 +217,35 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
|
|||
struct TestPassthroughInvalidOp : public ConversionPattern {
|
||||
TestPassthroughInvalidOp(MLIRContext *ctx)
|
||||
: ConversionPattern("test.invalid", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
|
||||
llvm::None);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
/// This pattern handles the case of a split return value.
|
||||
struct TestSplitReturnType : public ConversionPattern {
|
||||
TestSplitReturnType(MLIRContext *ctx)
|
||||
: ConversionPattern("test.return", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Check for a return of F32.
|
||||
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
|
||||
// Check if the first operation is a cast operation, if it is we use the
|
||||
// results directly.
|
||||
auto *defOp = operands[0].getDefiningOp();
|
||||
if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
|
||||
rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, fail to match.
|
||||
return matchFailure();
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -254,52 +254,52 @@ struct TestSplitReturnType : public ConversionPattern {
|
|||
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
|
||||
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// If the type is I32, change the type to F32.
|
||||
if (!Type(*op->result_type_begin()).isSignlessInteger(32))
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
|
||||
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// If the type is F32, change the type to F64.
|
||||
if (!Type(*op->result_type_begin()).isF32())
|
||||
return rewriter.notifyMatchFailure(op, "expected single f32 operand");
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
|
||||
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_producer", 10, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Always convert to B16, even though it is not a legal type. This tests
|
||||
// that values are unmapped correctly.
|
||||
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct TestUpdateConsumerType : public ConversionPattern {
|
||||
TestUpdateConsumerType(MLIRContext *ctx)
|
||||
: ConversionPattern("test.type_consumer", 1, ctx) {}
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// Verify that the incoming operand has been successfully remapped to F64.
|
||||
if (!operands[0].getType().isF64())
|
||||
return matchFailure();
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -312,15 +312,15 @@ struct TestNonRootReplacement : public RewritePattern {
|
|||
TestNonRootReplacement(MLIRContext *ctx)
|
||||
: RewritePattern("test.replace_non_root", 1, ctx) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
auto resultType = *op->result_type_begin();
|
||||
auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
|
||||
auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
|
||||
|
||||
rewriter.replaceOp(illegalOp, {legalOp});
|
||||
rewriter.replaceOp(op, {illegalOp});
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -475,7 +475,7 @@ struct OneVResOneVOperandOp1Converter
|
|||
: public OpConversionPattern<OneVResOneVOperandOp1> {
|
||||
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
|
||||
|
||||
PatternMatchResult
|
||||
LogicalResult
|
||||
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto origOps = op.getOperands();
|
||||
|
@ -490,7 +490,7 @@ struct OneVResOneVOperandOp1Converter
|
|||
|
||||
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
|
||||
remappedOperands);
|
||||
return matchSuccess();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -215,7 +215,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
|
|||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
if (depth != 0) {
|
||||
// Skip if there is no defining operation (e.g., arguments to function).
|
||||
os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
|
||||
os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
|
||||
depth);
|
||||
}
|
||||
if (tree.getNumArgs() != op.getNumArgs()) {
|
||||
|
@ -300,7 +300,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
|
|||
os.indent(indent) << "if (!("
|
||||
<< std::string(tgfmt(matcher.getConditionTemplate(),
|
||||
&fmtCtx.withSelf(self)))
|
||||
<< ")) return matchFailure();\n";
|
||||
<< ")) return failure();\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -344,7 +344,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
|
|||
// should just capture a mlir::Attribute() to signal the missing state.
|
||||
// That is precisely what getAttr() returns on missing attributes.
|
||||
} else {
|
||||
os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
|
||||
os.indent(indent) << "if (!tblgen_attr) return failure();\n";
|
||||
}
|
||||
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
|
@ -360,7 +360,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
|
|||
os.indent(indent) << "if (!("
|
||||
<< std::string(tgfmt(matcher.getConditionTemplate(),
|
||||
&fmtCtx.withSelf("tblgen_attr")))
|
||||
<< ")) return matchFailure();\n";
|
||||
<< ")) return failure();\n";
|
||||
}
|
||||
|
||||
// Capture the value
|
||||
|
@ -383,7 +383,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
|
|||
auto &entities = appliedConstraint.entities;
|
||||
|
||||
auto condition = constraint.getConditionTemplate();
|
||||
auto cmd = "if (!({0})) return matchFailure();\n";
|
||||
auto cmd = "if (!({0})) return failure();\n";
|
||||
|
||||
if (isa<TypeConstraint>(constraint)) {
|
||||
auto self = formatv("({0}.getType())",
|
||||
|
@ -468,7 +468,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
|
||||
// Emit matchAndRewrite() function.
|
||||
os << R"(
|
||||
PatternMatchResult matchAndRewrite(Operation *op0,
|
||||
LogicalResult matchAndRewrite(Operation *op0,
|
||||
PatternRewriter &rewriter) const override {
|
||||
)";
|
||||
|
||||
|
@ -501,7 +501,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
os.indent(4) << "// Rewrite\n";
|
||||
emitRewriteLogic();
|
||||
|
||||
os.indent(4) << "return matchSuccess();\n";
|
||||
os.indent(4) << "return success();\n";
|
||||
os << " };\n";
|
||||
os << "};\n";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue