[mlir][NFC] Replace all usages of PatternMatchResult with LogicalResult

This also replaces usages of matchSuccess/matchFailure with success/failure respectively.

Differential Revision: https://reviews.llvm.org/D76313
This commit is contained in:
River Riddle 2020-03-17 20:07:55 -07:00
parent 2fae7878d5
commit 3145427dd7
52 changed files with 722 additions and 743 deletions

View File

@ -247,7 +247,7 @@ struct MyConversionPattern : public ConversionPattern {
/// The `matchAndRewrite` hooks on ConversionPatterns take an additional
/// `operands` parameter, containing the remapped operands of the original
/// operation.
virtual PatternMatchResult
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const;
};

View File

@ -171,8 +171,8 @@ struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
PatternMatchResult match(Operation *op) const override {
return matchSuccess();
LogicalResult match(Operation *op) const override {
return success();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
@ -188,12 +188,12 @@ struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
return matchSuccess();
return success();
}
};
```

View File

@ -86,7 +86,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -96,11 +96,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
return success();
}
};
```

View File

@ -106,7 +106,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
/// Match and rewrite the given `toy.transpose` operation, with the given
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -132,7 +132,7 @@ struct TransposeOpLowering : public mlir::ConversionPattern {
SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
return success();
}
};
```

View File

@ -35,7 +35,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -45,11 +45,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
return success();
}
};

View File

@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
return success();
}
};

View File

@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
return matchSuccess();
return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return matchSuccess();
return success();
}
};
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return matchFailure();
return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
return success();
}
};
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
return success();
}
};

View File

@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
return success();
}
};

View File

@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
return matchSuccess();
return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return matchSuccess();
return success();
}
};
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return matchFailure();
return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
return success();
}
};
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
return success();
}
};

View File

@ -41,7 +41,7 @@ public:
explicit PrintOpLowering(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
@ -91,7 +91,7 @@ public:
// Notify the rewriter that this operation has been removed.
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:

View File

@ -40,7 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -50,11 +50,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
return success();
}
};

View File

@ -103,7 +103,7 @@ struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -126,7 +126,7 @@ struct BinaryOpLowering : public ConversionPattern {
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
return matchSuccess();
return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
@ -139,8 +139,8 @@ using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
@ -189,7 +189,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return matchSuccess();
return success();
}
};
@ -200,16 +200,16 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return matchFailure();
return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
return success();
}
};
@ -221,7 +221,7 @@ struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
@ -240,7 +240,7 @@ struct TransposeOpLowering : public ConversionPattern {
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
return success();
}
};

View File

@ -41,7 +41,7 @@ public:
explicit PrintOpLowering(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
@ -91,7 +91,7 @@ public:
// Notify the rewriter that this operation has been removed.
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:

View File

@ -58,7 +58,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
mlir::LogicalResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
@ -68,11 +68,11 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
// Input defined by another transpose? If not, no match.
if (!transposeInputOp)
return matchFailure();
return failure();
// Otherwise, we have a redundant transpose. Use the rewriter.
rewriter.replaceOp(op, {transposeInputOp.getOperand()});
return matchSuccess();
return success();
}
};

View File

@ -53,7 +53,7 @@ class TileAndFuseLinalgOp<
"if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
" \"" # value # "\")))" #
" return matchFailure();">;
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
@ -70,22 +70,22 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
"if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
" return failure();">;
class LinalgOpToParallelLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
" return failure();">;
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
"if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
" return failure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector patterns precondition and DRR.

View File

@ -54,10 +54,6 @@ private:
unsigned short representation;
};
/// This is the type returned by a pattern match.
/// TODO: Replace usages with LogicalResult directly.
using PatternMatchResult = LogicalResult;
//===----------------------------------------------------------------------===//
// Pattern class
//===----------------------------------------------------------------------===//
@ -85,20 +81,10 @@ public:
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual PatternMatchResult match(Operation *op) const = 0;
virtual LogicalResult match(Operation *op) const = 0;
virtual ~Pattern() {}
//===--------------------------------------------------------------------===//
// Helper methods to simplify pattern implementations
//===--------------------------------------------------------------------===//
/// Return a result, indicating that no match was found.
PatternMatchResult matchFailure() const { return failure(); }
/// This method indicates that a match was found.
PatternMatchResult matchSuccess() const { return success(); }
protected:
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching.
@ -130,22 +116,19 @@ public:
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). On failure, this
/// returns a None value. On success, it returns a (possibly null)
/// pattern-specific state wrapped in an Optional. This state is passed back
/// into the rewrite function if this match is selected.
PatternMatchResult match(Operation *op) const override;
/// which is the same operation code as getRootKind().
LogicalResult match(Operation *op) const override;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
/// function will automatically perform the rewrite.
virtual PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
virtual LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
/// Return a list of operations that may be generated when rewriting an
@ -182,11 +165,11 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), rewriter);
}
PatternMatchResult match(Operation *op) const final {
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), rewriter);
}
@ -195,16 +178,16 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual PatternMatchResult match(SourceOp op) const {
virtual LogicalResult match(SourceOp op) const {
llvm_unreachable("must override match or matchAndRewrite");
}
virtual PatternMatchResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
virtual LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter &rewriter) const {
if (succeeded(match(op))) {
rewrite(op, rewriter);
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};

View File

@ -235,18 +235,18 @@ public:
}
/// Hook for derived classes to implement combined matching and rewriting.
virtual PatternMatchResult
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
return matchFailure();
return failure();
rewrite(op, operands, rewriter);
return matchSuccess();
return success();
}
/// Attempt to match and rewrite the IR root at the specified operation.
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
private:
using RewritePattern::rewrite;
@ -266,7 +266,7 @@ struct OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
@ -282,13 +282,13 @@ struct OpConversionPattern : public ConversionPattern {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual PatternMatchResult
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (failed(match(op)))
return matchFailure();
return failure();
rewrite(op, operands, rewriter);
return matchSuccess();
return success();
}
private:

View File

@ -297,15 +297,15 @@ class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
public:
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineMinOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineMinOp op,
PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
if (!reduced)
return matchFailure();
return failure();
rewriter.replaceOp(op, reduced);
return matchSuccess();
return success();
}
};
@ -313,15 +313,15 @@ class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
public:
using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineMaxOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineMaxOp op,
PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
if (!reduced)
return matchFailure();
return failure();
rewriter.replaceOp(op, reduced);
return matchSuccess();
return success();
}
};
@ -330,10 +330,10 @@ class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
public:
using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineTerminatorOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineTerminatorOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<loop::YieldOp>(op);
return matchSuccess();
return success();
}
};
@ -341,8 +341,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
public:
using OpRewritePattern<AffineForOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineForOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineForOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
@ -351,7 +351,7 @@ public:
f.region().getBlocks().clear();
rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -359,8 +359,8 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
public:
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineIfOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineIfOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
// Now we just have to handle the condition logic.
@ -381,7 +381,7 @@ public:
operandsRef.take_front(numDims),
operandsRef.drop_front(numDims));
if (!affResult)
return matchFailure();
return failure();
auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
Value cmpVal =
rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
@ -402,7 +402,7 @@ public:
// Ok, we're done!
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -412,15 +412,15 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
public:
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineApplyOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineApplyOp op,
PatternRewriter &rewriter) const override {
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
llvm::to_vector<8>(op.getOperands()));
if (!maybeExpandedMap)
return matchFailure();
return failure();
rewriter.replaceOp(op, *maybeExpandedMap);
return matchSuccess();
return success();
}
};
@ -431,18 +431,18 @@ class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
public:
using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineLoadOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineLoadOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affineLoadOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
return matchFailure();
return failure();
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
return matchSuccess();
return success();
}
};
@ -453,20 +453,20 @@ class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
public:
using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffinePrefetchOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffinePrefetchOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affinePrefetchOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto resultOperands =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!resultOperands)
return matchFailure();
return failure();
// Build std.prefetch memref[expandedMap.results].
rewriter.replaceOpWithNewOp<PrefetchOp>(
op, op.memref(), *resultOperands, op.isWrite(),
op.localityHint().getZExtValue(), op.isDataCache());
return matchSuccess();
return success();
}
};
@ -477,19 +477,19 @@ class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
public:
using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineStoreOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineStoreOp op,
PatternRewriter &rewriter) const override {
// Expand affine map from 'affineStoreOp'.
SmallVector<Value, 8> indices(op.getMapOperands());
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!maybeExpandedMap)
return matchFailure();
return failure();
// Build std.store valueToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
op.getMemRef(), *maybeExpandedMap);
return matchSuccess();
return success();
}
};
@ -500,8 +500,8 @@ class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
public:
using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineDmaStartOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineDmaStartOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::makeArrayRef(operands);
@ -510,26 +510,26 @@ public:
rewriter, op.getLoc(), op.getSrcMap(),
operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
if (!maybeExpandedSrcMap)
return matchFailure();
return failure();
// Expand affine map for DMA destination memref.
auto maybeExpandedDstMap = expandAffineMap(
rewriter, op.getLoc(), op.getDstMap(),
operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
if (!maybeExpandedDstMap)
return matchFailure();
return failure();
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap = expandAffineMap(
rewriter, op.getLoc(), op.getTagMap(),
operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
if (!maybeExpandedTagMap)
return matchFailure();
return failure();
// Build std.dma_start operation with affine map results.
rewriter.replaceOpWithNewOp<DmaStartOp>(
op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
*maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
*maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
return matchSuccess();
return success();
}
};
@ -540,19 +540,19 @@ class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
public:
using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineDmaWaitOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineDmaWaitOp op,
PatternRewriter &rewriter) const override {
// Expand affine map for DMA tag memref.
SmallVector<Value, 8> indices(op.getTagIndices());
auto maybeExpandedTagMap =
expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
if (!maybeExpandedTagMap)
return matchFailure();
return failure();
// Build std.dma_wait operation with affine map results.
rewriter.replaceOpWithNewOp<DmaWaitOp>(
op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
return matchSuccess();
return success();
}
};

View File

@ -46,7 +46,7 @@ public:
indexBitwidth(getIndexBitWidth(lowering_)) {}
// Convert the kernel arguments to an LLVM type, preserve the rest.
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -63,7 +63,7 @@ public:
newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
break;
default:
return matchFailure();
return failure();
}
if (indexBitwidth > 32) {
@ -75,7 +75,7 @@ public:
}
rewriter.replaceOp(op, {newOp});
return matchSuccess();
return success();
}
};

View File

@ -34,7 +34,7 @@ public:
lowering_.getDialect()->getContext(), lowering_),
f32Func(f32Func), f64Func(f64Func) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
using LLVM::LLVMFuncOp;
@ -49,13 +49,13 @@ public:
LLVMType funcType = getFunctionType(resultType, operands);
StringRef funcName = getFunctionName(resultType);
if (funcName.empty())
return matchFailure();
return failure();
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
auto callOp = rewriter.create<LLVM::CallOp>(
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
rewriter.replaceOp(op, {callOp.getResult(0)});
return matchSuccess();
return success();
}
private:

View File

@ -51,7 +51,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
/// !llvm<"{ float, i1 }">
/// %shfl_pred = llvm.extractvalue %shfl[1 : index] :
/// !llvm<"{ float, i1 }">
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
@ -84,7 +84,7 @@ struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
return matchSuccess();
return success();
}
};
@ -94,7 +94,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
typeConverter.getDialect()->getContext(),
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
@ -219,7 +219,7 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
signatureConversion);
rewriter.eraseOp(gpuFuncOp);
return matchSuccess();
return success();
}
};
@ -229,11 +229,11 @@ struct GPUReturnOpLowering : public ConvertToLLVMPattern {
typeConverter.getDialect()->getContext(),
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return matchSuccess();
return success();
}
};

View File

@ -26,7 +26,7 @@ class ForOpConversion final : public SPIRVOpLowering<loop::ForOp> {
public:
using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -37,7 +37,7 @@ class IfOpConversion final : public SPIRVOpLowering<loop::IfOp> {
public:
using SPIRVOpLowering<loop::IfOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(loop::IfOp IfOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -47,11 +47,11 @@ class TerminatorOpConversion final : public SPIRVOpLowering<loop::YieldOp> {
public:
using SPIRVOpLowering<loop::YieldOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(terminatorOp);
return matchSuccess();
return success();
}
};
@ -62,7 +62,7 @@ class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
public:
using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -75,7 +75,7 @@ class WorkGroupSizeConversion : public SPIRVOpLowering<gpu::BlockDimOp> {
public:
using SPIRVOpLowering<gpu::BlockDimOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(gpu::BlockDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -85,7 +85,7 @@ class GPUFuncOpConversion final : public SPIRVOpLowering<gpu::GPUFuncOp> {
public:
using SPIRVOpLowering<gpu::GPUFuncOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
@ -98,7 +98,7 @@ class GPUModuleConversion final : public SPIRVOpLowering<gpu::GPUModuleOp> {
public:
using SPIRVOpLowering<gpu::GPUModuleOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -109,7 +109,7 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
public:
using SPIRVOpLowering<gpu::ReturnOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -120,7 +120,7 @@ public:
// loop::ForOp.
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// loop::ForOp can be lowered to the structured control flow represented by
@ -186,14 +186,14 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
rewriter.eraseOp(forOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// loop::IfOp.
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// When lowering `loop::IfOp` we explicitly create a selection header block
@ -238,7 +238,7 @@ IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef<Value> operands,
elseBlock, ArrayRef<Value>());
rewriter.eraseOp(ifOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
@ -261,36 +261,36 @@ static Optional<int32_t> getLaunchConfigIndex(Operation *op) {
}
template <typename SourceOp, spirv::BuiltIn builtin>
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
return this->matchFailure();
return failure();
// SPIR-V invocation builtin variables are a vector of type <3xi32>
auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
op, rewriter.getIntegerType(32), spirvBuiltin,
rewriter.getI32ArrayAttr({index.getValue()}));
return this->matchSuccess();
return success();
}
PatternMatchResult WorkGroupSizeConversion::matchAndRewrite(
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
gpu::BlockDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto index = getLaunchConfigIndex(op);
if (!index)
return matchFailure();
return failure();
auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
auto convertedType = typeConverter.convertType(op.getResult().getType());
if (!convertedType)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
op, convertedType, IntegerAttr::get(convertedType, val));
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
@ -343,11 +343,11 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
return newFuncOp;
}
PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
LogicalResult GPUFuncOpConversion::matchAndRewrite(
gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!gpu::GPUDialect::isKernel(funcOp))
return matchFailure();
return failure();
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
@ -358,22 +358,22 @@ PatternMatchResult GPUFuncOpConversion::matchAndRewrite(
auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
if (!entryPointAttr) {
funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute");
return matchFailure();
return failure();
}
spirv::FuncOp newFuncOp = lowerAsEntryFunction(
funcOp, typeConverter, rewriter, entryPointAttr, argABI);
if (!newFuncOp)
return matchFailure();
return failure();
newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
rewriter.getContext()));
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//
PatternMatchResult GPUModuleConversion::matchAndRewrite(
LogicalResult GPUModuleConversion::matchAndRewrite(
gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto spvModule = rewriter.create<spirv::ModuleOp>(
@ -389,21 +389,21 @@ PatternMatchResult GPUModuleConversion::matchAndRewrite(
// legalized later.
spvModuleRegion.back().erase();
rewriter.eraseOp(moduleOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//
PatternMatchResult GPUReturnOpConversion::matchAndRewrite(
LogicalResult GPUReturnOpConversion::matchAndRewrite(
gpu::ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!operands.empty())
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -130,7 +130,7 @@ public:
explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto rangeOp = cast<RangeOp>(op);
@ -146,7 +146,7 @@ public:
desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1));
desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2));
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};
@ -160,14 +160,14 @@ public:
: ConvertToLLVMPattern(ReshapeOp::getOperationName(), context,
lowering_) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reshapeOp = cast<ReshapeOp>(op);
MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
if (!dstType.hasStaticShape())
return matchFailure();
return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
@ -175,7 +175,7 @@ public:
if (failed(res) || llvm::any_of(strides, [](int64_t val) {
return ShapedType::isDynamicStrideOrOffset(val);
}))
return matchFailure();
return failure();
edsc::ScopedContext context(rewriter, op->getLoc());
ReshapeOpOperandAdaptor adaptor(operands);
@ -189,7 +189,7 @@ public:
for (auto en : llvm::enumerate(strides))
desc.setConstantStride(en.index(), en.value());
rewriter.replaceOp(op, {desc});
return matchSuccess();
return success();
}
};
@ -204,7 +204,7 @@ public:
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext context(rewriter, op->getLoc());
@ -247,7 +247,7 @@ public:
// Corner case, no sizes or strides: early return the descriptor.
if (sliceOp.getShapedType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), matchSuccess();
return rewriter.replaceOp(op, {desc}), success();
Value zero = llvm_constant(
int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
@ -279,7 +279,7 @@ public:
}
rewriter.replaceOp(op, {desc});
return matchSuccess();
return success();
}
};
@ -297,7 +297,7 @@ public:
: ConvertToLLVMPattern(TransposeOp::getOperationName(), context,
lowering_) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Initialize the common boilerplate and alloca at the top of the FuncOp.
@ -308,7 +308,7 @@ public:
auto transposeOp = cast<TransposeOp>(op);
// No permutation, early exit.
if (transposeOp.permutation().isIdentity())
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
return rewriter.replaceOp(op, {baseDesc}), success();
BaseViewConversionHelper desc(
typeConverter.convertType(transposeOp.getShapedType()));
@ -330,7 +330,7 @@ public:
}
rewriter.replaceOp(op, {desc});
return matchSuccess();
return success();
}
};
@ -340,11 +340,11 @@ public:
explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return matchSuccess();
return success();
}
};
} // namespace
@ -416,15 +416,15 @@ class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
public:
using OpRewritePattern<LinalgOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
if (!libraryCallName)
return this->matchFailure();
return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
return this->matchSuccess();
return success();
}
};
@ -434,22 +434,22 @@ template <> class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
auto inputPerm = op.inputPermutation();
if (inputPerm.hasValue() && !inputPerm->isIdentity())
return matchFailure();
return failure();
auto outputPerm = op.outputPermutation();
if (outputPerm.hasValue() && !outputPerm->isIdentity())
return matchFailure();
return failure();
auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
if (!libraryCallName)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), ArrayRef<Type>{}, op.getOperands());
return matchSuccess();
return success();
}
};
@ -460,12 +460,12 @@ class LinalgOpConversion<IndexedGenericOp>
public:
using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(IndexedGenericOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(IndexedGenericOp op,
PatternRewriter &rewriter) const override {
auto libraryCallName =
getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
if (!libraryCallName)
return this->matchFailure();
return failure();
// TODO(pifon, ntv): Use induction variables values instead of zeros, when
// IndexedGenericOp is tiled.
@ -483,7 +483,7 @@ public:
}
rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
ArrayRef<Type>{}, operands);
return this->matchSuccess();
return success();
}
};
@ -495,8 +495,8 @@ class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
using OpRewritePattern<CopyOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(CopyOp op,
PatternRewriter &rewriter) const override {
Value in = op.input(), out = op.output();
// If either inputPerm or outputPerm are non-identities, insert transposes.
@ -511,10 +511,10 @@ public:
// If nothing was transposed, fail and let the conversion kick in.
if (in == op.input() && out == op.output())
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
return matchSuccess();
return success();
}
};

View File

@ -54,7 +54,7 @@ public:
static Optional<linalg::RegionMatcher::BinaryOpKind>
matchAsPerformingReduction(linalg::GenericOp genericOp);
PatternMatchResult
LogicalResult
matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -109,7 +109,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
}
PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
LogicalResult SingleWorkgroupReduction::matchAndRewrite(
linalg::GenericOp genericOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Operation *op = genericOp.getOperation();
@ -118,19 +118,19 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
auto binaryOpKind = matchAsPerformingReduction(genericOp);
if (!binaryOpKind)
return matchFailure();
return failure();
// Query the shader interface for local workgroup size to make sure the
// invocation configuration fits with the input memref's shape.
DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
if (!localSize)
return matchFailure();
return failure();
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
return matchFailure();
return failure();
if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
[](const APInt &size) { return !size.isOneValue(); }))
return matchFailure();
return failure();
// TODO(antiagainst): Query the target environment to make sure the current
// workload fits in a local workgroup.
@ -195,7 +195,7 @@ PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);
rewriter.eraseOp(genericOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -98,8 +98,8 @@ struct LoopToStandardPass : public OperationPass<LoopToStandardPass> {
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
};
// Create a CFG subgraph for the loop.if operation (including its "then" and
@ -147,20 +147,20 @@ struct ForLowering : public OpRewritePattern<ForOp> {
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
};
struct ParallelLowering : public OpRewritePattern<mlir::loop::ParallelOp> {
using OpRewritePattern<mlir::loop::ParallelOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(mlir::loop::ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
};
} // namespace
PatternMatchResult
ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
// Start by splitting the block containing the 'loop.for' into two parts.
@ -189,7 +189,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
auto step = forOp.step();
auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
if (!stepped)
return matchFailure();
return failure();
SmallVector<Value, 8> loopCarried;
loopCarried.push_back(stepped);
@ -202,7 +202,7 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
Value lowerBound = forOp.lowerBound();
Value upperBound = forOp.upperBound();
if (!lowerBound || !upperBound)
return matchFailure();
return failure();
// The initial values of loop-carried values is obtained from the operands
// of the loop operation.
@ -222,11 +222,11 @@ ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
// The result of the loop operation is the values of the condition block
// arguments except the induction variable on the last iteration.
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
return matchSuccess();
return success();
}
PatternMatchResult
IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();
// Start by splitting the block containing the 'loop.if' into two parts.
@ -265,10 +265,10 @@ IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
// Ok, we're done!
rewriter.eraseOp(ifOp);
return matchSuccess();
return success();
}
PatternMatchResult
LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
Location loc = parallelOp.getLoc();
@ -344,7 +344,7 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.replaceOp(parallelOp, loopResults);
return matchSuccess();
return success();
}
void mlir::populateLoopToStdConversionPatterns(

View File

@ -497,8 +497,8 @@ namespace {
struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const override;
};
struct MappingAnnotation {
@ -742,7 +742,7 @@ static LogicalResult processParallelLoop(ParallelOp parallelOp,
/// the actual loop bound. This only works if an static upper bound for the
/// dynamic loop bound can be defived, currently via analyzing `affine.min`
/// operations.
PatternMatchResult
LogicalResult
ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const {
// Create a launch operation. We start with bound one for all grid/block
@ -761,7 +761,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
SmallVector<Operation *, 16> worklist;
if (failed(processParallelLoop(parallelOp, launchOp, cloningMap, worklist,
launchBounds, rewriter)))
return matchFailure();
return failure();
// Whether we have seen any side-effects. Reset when leaving an inner scope.
bool seenSideeffects = false;
@ -778,13 +778,13 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
// Before entering a nested scope, make sure there have been no
// sideeffects until now.
if (seenSideeffects)
return matchFailure();
return failure();
// A nested loop.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
if (failed(processParallelLoop(nestedParallel, launchOp, cloningMap,
worklist, launchBounds, rewriter)))
return matchFailure();
return failure();
} else if (op == launchOp.getOperation()) {
// Found our sentinel value. We have finished the operations from one
// nesting level, pop one level back up.
@ -802,7 +802,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
clone->getNumRegions() != 0;
// If we are no longer in the innermost scope, sideeffects are disallowed.
if (seenSideeffects && leftNestingScope)
return matchFailure();
return failure();
}
}
@ -812,7 +812,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
launchOp.setOperand(std::get<0>(bound), std::get<1>(bound));
rewriter.eraseOp(parallelOp);
return matchSuccess();
return success();
}
void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,

View File

@ -946,7 +946,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
bool emitCWrappers)
: FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
@ -962,7 +962,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
}
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:
@ -976,7 +976,7 @@ private:
struct BarePtrFuncOpConversion : public FuncOpConversionBase {
using FuncOpConversionBase::FuncOpConversionBase;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
@ -990,7 +990,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (newFuncOp.getBody().empty()) {
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
// Promote bare pointers from MemRef arguments to a MemRef descriptor struct
@ -1017,7 +1017,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
}
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -1109,7 +1109,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numResults = op->getNumResults();
@ -1119,7 +1119,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
packedType =
this->typeConverter.packFunctionResults(op->getResultTypes());
if (!packedType)
return this->matchFailure();
return failure();
}
auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
@ -1127,10 +1127,10 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return rewriter.eraseOp(op), this->matchSuccess();
return rewriter.eraseOp(op), success();
if (numResults == 1)
return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
this->matchSuccess();
success();
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
@ -1143,7 +1143,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
rewriter.getI64ArrayAttr(i)));
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
return success();
}
};
@ -1207,7 +1207,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Convert the type of the result to an LLVM type, pass operands as is,
// preserve attributes.
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ValidateOpCount<SourceOp, OpCount>();
@ -1221,7 +1221,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
// Cannot convert ops if their operands are not of LLVM type.
for (Value operand : operands) {
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
return this->matchFailure();
return failure();
}
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
@ -1230,7 +1230,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
auto newOp = rewriter.create<TargetOp>(
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
rewriter.replaceOp(op, newOp.getResult());
return this->matchSuccess();
return success();
}
if (succeeded(HandleMultidimensionalVectors(
@ -1240,8 +1240,8 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
operands, op->getAttrs());
},
rewriter)))
return this->matchSuccess();
return this->matchFailure();
return success();
return failure();
}
};
@ -1381,24 +1381,24 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
: LLVMLegalizationPattern<AllocOp>(dialect_, converter),
useAlloca(useAlloca) {}
PatternMatchResult match(Operation *op) const override {
LogicalResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
if (isSupportedMemRefType(type))
return matchSuccess();
return success();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(type, strides, offset);
if (failed(successStrides))
return matchFailure();
return failure();
// Dynamic strides are ok if they can be deduced from dynamic sizes (which
// is guaranteed when succeeded(successStrides)). Dynamic offset however can
// never be alloc'ed.
if (offset == MemRefType::getDynamicStrideOrOffset())
return matchFailure();
return failure();
return matchSuccess();
return success();
}
void rewrite(Operation *op, ArrayRef<Value> operands,
@ -1574,7 +1574,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = LLVMLegalizationPattern<CallOpType>;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<CallOpType> transformed(operands);
@ -1595,7 +1595,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
if (numResults != 0) {
if (!(packedResult =
this->typeConverter.packFunctionResults(resultTypes)))
return this->matchFailure();
return failure();
}
auto promoted = this->typeConverter.promoteMemRefDescriptors(
@ -1606,7 +1606,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
// If < 2 results, packing did not do anything and we can just return.
if (numResults < 2) {
rewriter.replaceOp(op, newOp.getResults());
return this->matchSuccess();
return success();
}
// Otherwise, it had been converted to an operation producing a structure.
@ -1624,7 +1624,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
}
rewriter.replaceOp(op, results);
return this->matchSuccess();
return success();
}
};
@ -1647,11 +1647,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
: LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
useAlloca(useAlloca) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (useAlloca)
return rewriter.eraseOp(op), matchSuccess();
return rewriter.eraseOp(op), success();
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
@ -1673,7 +1673,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
memref.allocatedPtr(rewriter, op->getLoc()));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
return success();
}
bool useAlloca;
@ -1683,7 +1683,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<RsqrtOp> transformed(operands);
@ -1691,7 +1691,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
return failure();
auto loc = op->getLoc();
auto resultType = *op->result_type_begin();
@ -1709,12 +1709,12 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
return this->matchSuccess();
return success();
}
auto vectorType = resultType.dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
return failure();
if (succeeded(HandleMultidimensionalVectors(
op, operands, typeConverter,
@ -1732,8 +1732,8 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
sqrt);
},
rewriter)))
return this->matchSuccess();
return this->matchFailure();
return success();
return failure();
}
};
@ -1741,7 +1741,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
@ -1753,7 +1753,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
return failure();
std::string functionName;
if (operandType.isFloatTy())
@ -1761,7 +1761,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
else if (operandType.isDoubleTy())
functionName = "tanh";
else
return matchFailure();
return failure();
// Get a reference to the tanh function, inserting it if necessary.
Operation *tanhFunc =
@ -1783,14 +1783,14 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc),
transformed.operand());
return matchSuccess();
return success();
}
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
PatternMatchResult match(Operation *op) const override {
LogicalResult match(Operation *op) const override {
auto memRefCastOp = cast<MemRefCastOp>(op);
Type srcType = memRefCastOp.getOperand().getType();
Type dstType = memRefCastOp.getType();
@ -1801,8 +1801,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
return (isSupportedMemRefType(targetType) &&
isSupportedMemRefType(sourceType))
? matchSuccess()
: matchFailure();
? success()
: failure();
}
// At least one of the operands is unranked type
@ -1812,8 +1812,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
// Unranked to unranked cast is disallowed
return !(srcType.isa<UnrankedMemRefType>() &&
dstType.isa<UnrankedMemRefType>())
? matchSuccess()
: matchFailure();
? success()
: failure();
}
void rewrite(Operation *op, ArrayRef<Value> operands,
@ -1886,17 +1886,17 @@ struct DialectCastOpLowering
: public LLVMLegalizationPattern<LLVM::DialectCastOp> {
using LLVMLegalizationPattern<LLVM::DialectCastOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto castOp = cast<LLVM::DialectCastOp>(op);
OperandAdaptor<LLVM::DialectCastOp> transformed(operands);
if (transformed.in().getType() !=
typeConverter.convertType(castOp.getType())) {
return matchFailure();
return failure();
}
rewriter.replaceOp(op, transformed.in());
return matchSuccess();
return success();
}
};
@ -1905,7 +1905,7 @@ struct DialectCastOpLowering
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto dimOp = cast<DimOp>(op);
@ -1922,7 +1922,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
// Use constant for static size.
rewriter.replaceOp(
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
return matchSuccess();
return success();
}
};
@ -1934,10 +1934,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
using Base = LoadStoreOpLowering<Derived>;
PatternMatchResult match(Operation *op) const override {
LogicalResult match(Operation *op) const override {
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? this->matchSuccess()
: this->matchFailure();
return isSupportedMemRefType(type) ? success() : failure();
}
// Given subscript indices and array sizes in row-major order,
@ -2010,7 +2009,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
using Base::Base;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loadOp = cast<LoadOp>(op);
@ -2020,7 +2019,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return matchSuccess();
return success();
}
};
@ -2029,7 +2028,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
using Base::Base;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto type = cast<StoreOp>(op).getMemRefType();
@ -2039,7 +2038,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
transformed.indices(), rewriter, getModule());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
dataPtr);
return matchSuccess();
return success();
}
};
@ -2048,7 +2047,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
using Base::Base;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto prefetchOp = cast<PrefetchOp>(op);
@ -2072,7 +2071,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op, dataPtr, isWrite,
localityHint, isData);
return matchSuccess();
return success();
}
};
@ -2083,7 +2082,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOpOperandAdaptor transformed(operands);
@ -2104,7 +2103,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
else
rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
transformed.in());
return matchSuccess();
return success();
}
};
@ -2118,7 +2117,7 @@ static LLVMPredType convertCmpPredicate(StdPredType pred) {
struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpiOp = cast<CmpIOp>(op);
@ -2130,14 +2129,14 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
return success();
}
};
struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpfOp = cast<CmpFOp>(op);
@ -2149,7 +2148,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
transformed.lhs(), transformed.rhs());
return matchSuccess();
return success();
}
};
@ -2189,12 +2188,12 @@ struct OneToOneLLVMTerminatorLowering
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
op->getAttrs());
return this->matchSuccess();
return success();
}
};
@ -2207,7 +2206,7 @@ struct OneToOneLLVMTerminatorLowering
struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op->getNumOperands();
@ -2216,12 +2215,12 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Type>(), ArrayRef<Value>(), op->getAttrs());
return matchSuccess();
return success();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Type>(), operands.front(), op->getAttrs());
return matchSuccess();
return success();
}
// Otherwise, we need to pack the arguments into an LLVM struct type before
@ -2237,7 +2236,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, ArrayRef<Type>(), packed,
op->getAttrs());
return matchSuccess();
return success();
}
};
@ -2256,13 +2255,13 @@ struct CondBranchOpLowering
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() != 1)
return matchFailure();
return failure();
// First insert it into an undef vector so we can shuffle it.
auto vectorType = typeConverter.convertType(splatOp.getType());
@ -2280,7 +2279,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
// Shuffle the value across the desired number of elements.
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
return matchSuccess();
return success();
}
};
@ -2290,14 +2289,14 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto splatOp = cast<SplatOp>(op);
OperandAdaptor<SplatOp> adaptor(operands);
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
if (!resultType || resultType.getRank() == 1)
return matchFailure();
return failure();
// First insert it into an undef vector so we can shuffle it.
auto loc = op->getLoc();
@ -2305,7 +2304,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmArrayTy || !llvmVectorTy)
return matchFailure();
return failure();
// Construct returned value.
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
@ -2332,7 +2331,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
position);
});
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};
@ -2344,7 +2343,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -2376,7 +2375,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
auto targetDescTy = typeConverter.convertType(viewMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!sourceElementTy || !targetDescTy)
return matchFailure();
return failure();
// Currently, only rank > 0 and full or no operands are supported. Fail to
// convert otherwise.
@ -2385,22 +2384,22 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
(!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
(!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
(!dynamicStrides.empty() && rank != dynamicStrides.size()))
return matchFailure();
return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
return matchFailure();
return failure();
// Fail to convert if neither a dynamic nor static offset is available.
if (dynamicOffsets.empty() &&
offset == MemRefType::getDynamicStrideOrOffset())
return matchFailure();
return failure();
// Create the descriptor.
if (!operands.front().getType().isa<LLVM::LLVMType>())
return matchFailure();
return failure();
MemRefDescriptor sourceMemRef(operands.front());
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@ -2460,7 +2459,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
}
rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
return success();
}
};
@ -2505,7 +2504,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
return createIndexConstant(rewriter, loc, 1);
}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -2520,14 +2519,13 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
typeConverter.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
if (!targetDescTy)
return op->emitWarning("Target descriptor type not converted to LLVM"),
matchFailure();
failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
if (failed(successStrides))
return op->emitWarning("cannot cast to non-strided shape"),
matchFailure();
return op->emitWarning("cannot cast to non-strided shape"), failure();
// Create the descriptor.
MemRefDescriptor sourceMemRef(adaptor.source());
@ -2560,12 +2558,11 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
// Early exit for 0-D corner case.
if (viewMemRefType.getRank() == 0)
return rewriter.replaceOp(op, {targetMemRef}), matchSuccess();
return rewriter.replaceOp(op, {targetMemRef}), success();
// Fields 4 and 5: Update sizes and strides.
if (strides.back() != 1)
return op->emitWarning("cannot cast to non-contiguous shape"),
matchFailure();
return op->emitWarning("cannot cast to non-contiguous shape"), failure();
Value stride = nullptr, nextSize = nullptr;
// Drop the dynamic stride from the operand list, if present.
ArrayRef<Value> sizeOperands(sizeAndOffsetOperands);
@ -2583,7 +2580,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
}
rewriter.replaceOp(op, {targetMemRef});
return matchSuccess();
return success();
}
};
@ -2591,7 +2588,7 @@ struct AssumeAlignmentOpLowering
: public LLVMLegalizationPattern<AssumeAlignmentOp> {
using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<AssumeAlignmentOp> transformed(operands);
@ -2622,7 +2619,7 @@ struct AssumeAlignmentOpLowering
rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -2657,13 +2654,13 @@ namespace {
struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (!maybeKind)
return matchFailure();
return failure();
OperandAdaptor<AtomicRMWOp> adaptor(operands);
auto resultType = adaptor.value().getType();
auto memRefType = atomicOp.getMemRefType();
@ -2672,7 +2669,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
op, resultType, *maybeKind, dataPtr, adaptor.value(),
LLVM::AtomicOrdering::acq_rel);
return matchSuccess();
return success();
}
};
@ -2706,13 +2703,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
using Base::Base;
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto atomicOp = cast<AtomicRMWOp>(op);
auto maybeKind = matchSimpleAtomicOp(atomicOp);
if (maybeKind)
return matchFailure();
return failure();
LLVM::FCmpPredicate predicate;
switch (atomicOp.kind()) {
@ -2723,7 +2720,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
predicate = LLVM::FCmpPredicate::olt;
break;
default:
return matchFailure();
return failure();
}
OperandAdaptor<AtomicRMWOp> adaptor(operands);
@ -2779,7 +2776,7 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
return matchSuccess();
return success();
}
};

View File

@ -31,7 +31,7 @@ class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -45,7 +45,7 @@ class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -55,7 +55,7 @@ class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
public:
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -65,7 +65,7 @@ class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -81,14 +81,14 @@ class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto resultType =
this->typeConverter.convertType(operation.getResult().getType());
rewriter.template replaceOpWithNewOp<SPIRVOp>(
operation, resultType, operands, ArrayRef<NamedAttribute>());
return this->matchSuccess();
return success();
}
};
@ -100,7 +100,7 @@ class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -111,7 +111,7 @@ class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
public:
using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -121,7 +121,7 @@ public:
class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -134,7 +134,7 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -145,22 +145,22 @@ public:
// ConstantOp with composite type.
//===----------------------------------------------------------------------===//
PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
LogicalResult ConstantCompositeOpConversion::matchAndRewrite(
ConstantOp constCompositeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto compositeType =
constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
if (!compositeType)
return matchFailure();
return failure();
auto spirvCompositeType = typeConverter.convertType(compositeType);
if (!spirvCompositeType)
return matchFailure();
return failure();
auto linearizedElements =
constCompositeOp.value().dyn_cast<DenseElementsAttr>();
if (!linearizedElements)
return matchFailure();
return failure();
// If composite type has rank greater than one, then perform linearization.
if (compositeType.getRank() > 1) {
@ -171,24 +171,24 @@ PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
constCompositeOp, spirvCompositeType, linearizedElements);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// ConstantOp with index type.
//===----------------------------------------------------------------------===//
PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
LogicalResult ConstantIndexOpConversion::matchAndRewrite(
ConstantOp constIndexOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!constIndexOp.getResult().getType().isa<IndexType>()) {
return matchFailure();
return failure();
}
// The attribute has index type which is not directly supported in
// SPIR-V. Get the integer value and create a new IntegerAttr.
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
if (!constAttr) {
return matchFailure();
return failure();
}
// Use the bitwidth set in the value attribute to decide the result type
@ -197,7 +197,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
return matchFailure();
return failure();
}
auto spirvConstType =
typeConverter.convertType(constIndexOp.getResult().getType());
@ -205,14 +205,14 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
spirvConstVal);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpFOpOperandAdaptor cmpFOpOperands(operands);
@ -223,7 +223,7 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
cmpFOpOperands.lhs(), \
cmpFOpOperands.rhs()); \
return matchSuccess();
return success();
// Ordered.
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
@ -245,14 +245,14 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
default:
break;
}
return matchFailure();
return failure();
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
CmpIOpOperandAdaptor cmpIOpOperands(operands);
@ -263,7 +263,7 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
return matchSuccess();
return success();
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
@ -278,14 +278,14 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
#undef DISPATCH
}
return matchFailure();
return failure();
}
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
LoadOpOperandAdaptor loadOperands(operands);
@ -293,42 +293,42 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
typeConverter, loadOp.memref().getType().cast<MemRefType>(),
loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (returnOp.getNumOperands()) {
return matchFailure();
return failure();
}
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
StoreOpOperandAdaptor storeOperands(operands);
@ -338,7 +338,7 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
storeOperands.value());
return matchSuccess();
return success();
}
namespace {

View File

@ -26,8 +26,8 @@ class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
public:
using OpRewritePattern<LoadOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const override;
};
/// Merges subview operation with store operation.
@ -35,8 +35,8 @@ class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
public:
using OpRewritePattern<StoreOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const override;
};
} // namespace
@ -107,43 +107,43 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
// Folding SubViewOp and LoadOp.
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
if (!subViewOp) {
return matchFailure();
return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
sourceIndices);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and StoreOp.
//===----------------------------------------------------------------------===//
PatternMatchResult
LogicalResult
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
if (!subViewOp) {
return matchFailure();
return failure();
}
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
subViewOp.source(), sourceIndices);
return matchSuccess();
return success();
}
//===----------------------------------------------------------------------===//

View File

@ -133,13 +133,13 @@ public:
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
if (typeConverter.convertType(dstVectorType) == nullptr)
return matchFailure();
return failure();
// Rewrite when the full vector type can be lowered (which
// implies all 'reduced' types can be lowered too).
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
@ -149,7 +149,7 @@ public:
op, expandRanks(adaptor.source(), // source value to be expanded
op->getLoc(), // location of original broadcast
srcVectorType, dstVectorType, rewriter));
return matchSuccess();
return success();
}
private:
@ -284,7 +284,7 @@ public:
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
@ -293,7 +293,7 @@ public:
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
return matchSuccess();
return success();
}
};
@ -304,7 +304,7 @@ public:
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto reductionOp = cast<vector::ReductionOp>(op);
@ -335,8 +335,8 @@ public:
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
op, llvmType, operands[0]);
else
return matchFailure();
return matchSuccess();
return failure();
return success();
} else if (eltType.isF32() || eltType.isF64()) {
// Floating-point reductions: add/mul/min/max
@ -364,10 +364,10 @@ public:
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
op, llvmType, operands[0]);
else
return matchFailure();
return matchSuccess();
return failure();
return success();
}
return matchFailure();
return failure();
}
};
@ -378,7 +378,7 @@ public:
: ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -392,7 +392,7 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
// Get rank and dimension sizes.
int64_t rank = vectorType.getRank();
@ -406,7 +406,7 @@ public:
Value shuffle = rewriter.create<LLVM::ShuffleVectorOp>(
loc, adaptor.v1(), adaptor.v2(), maskArrayAttr);
rewriter.replaceOp(op, shuffle);
return matchSuccess();
return success();
}
// For all other cases, insert the individual values individually.
@ -425,7 +425,7 @@ public:
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
return matchSuccess();
return success();
}
};
@ -436,7 +436,7 @@ public:
: ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
@ -446,11 +446,11 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, llvmType, adaptor.vector(), adaptor.position());
return matchSuccess();
return success();
}
};
@ -461,7 +461,7 @@ public:
: ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -474,14 +474,14 @@ public:
// Bail if result type cannot be lowered.
if (!llvmResultType)
return matchFailure();
return failure();
// One-shot extraction of vector from array (only requires extractvalue).
if (resultType.isa<VectorType>()) {
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmResultType, adaptor.vector(), positionArrayAttr);
rewriter.replaceOp(op, extracted);
return matchSuccess();
return success();
}
// Potential extraction of 1-D vector from array.
@ -505,7 +505,7 @@ public:
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(op, extracted);
return matchSuccess();
return success();
}
};
@ -530,17 +530,17 @@ public:
: ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::FMAOpOperandAdaptor(operands);
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
VectorType vType = fmaOp.getVectorType();
if (vType.getRank() != 1)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
adaptor.acc());
return matchSuccess();
return success();
}
};
@ -551,7 +551,7 @@ public:
: ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
@ -561,11 +561,11 @@ public:
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
return matchSuccess();
return success();
}
};
@ -576,7 +576,7 @@ public:
: ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -589,7 +589,7 @@ public:
// Bail if result type cannot be lowered.
if (!llvmResultType)
return matchFailure();
return failure();
// One-shot insertion of a vector into an array (only requires insertvalue).
if (sourceType.isa<VectorType>()) {
@ -597,7 +597,7 @@ public:
loc, llvmResultType, adaptor.dest(), adaptor.source(),
positionArrayAttr);
rewriter.replaceOp(op, inserted);
return matchSuccess();
return success();
}
// Potential extraction of 1-D vector from array.
@ -632,7 +632,7 @@ public:
}
rewriter.replaceOp(op, inserted);
return matchSuccess();
return success();
}
};
@ -661,11 +661,11 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
using OpRewritePattern<FMAOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
auto vType = op.getVectorType();
if (vType.getRank() < 2)
return matchFailure();
return failure();
auto loc = op.getLoc();
auto elemType = vType.getElementType();
@ -680,7 +680,7 @@ public:
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
}
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};
@ -704,19 +704,19 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return matchFailure();
return failure();
auto loc = op.getLoc();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff == 0)
return matchFailure();
return failure();
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
@ -735,7 +735,7 @@ public:
op, stridedSliceInnerOp.getResult(), op.dest(),
getI64SubArray(op.offsets(), /*dropFront=*/0,
/*dropFront=*/rankRest));
return matchSuccess();
return success();
}
};
@ -753,22 +753,22 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.offsets().getValue().empty())
return matchFailure();
return failure();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff != 0)
return matchFailure();
return failure();
if (srcType == dstType) {
rewriter.replaceOp(op, op.source());
return matchSuccess();
return success();
}
int64_t offset =
@ -813,7 +813,7 @@ public:
}
rewriter.replaceOp(op, res);
return matchSuccess();
return success();
}
};
@ -824,7 +824,7 @@ public:
: ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
@ -837,18 +837,18 @@ public:
// Only static shape casts supported atm.
if (!sourceMemRefType.hasStaticShape() ||
!targetMemRefType.hasStaticShape())
return matchFailure();
return failure();
auto llvmSourceDescriptorTy =
operands[0].getType().dyn_cast<LLVM::LLVMType>();
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
return matchFailure();
return failure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
return failure();
int64_t offset;
SmallVector<int64_t, 4> strides;
@ -866,7 +866,7 @@ public:
}
// Only contiguous source tensors supported atm.
if (failed(successStrides) || !isContiguous)
return matchFailure();
return failure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
@ -901,7 +901,7 @@ public:
}
rewriter.replaceOp(op, {desc});
return matchSuccess();
return success();
}
};
@ -924,7 +924,7 @@ public:
//
// TODO(ajcbik): rely solely on libc in future? something else?
//
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto printOp = cast<vector::PrintOp>(op);
@ -932,7 +932,7 @@ public:
Type printType = printOp.getPrintType();
if (typeConverter.convertType(printType) == nullptr)
return matchFailure();
return failure();
// Make sure element type has runtime support (currently just Float/Double).
VectorType vectorType = printType.dyn_cast<VectorType>();
@ -948,13 +948,13 @@ public:
else if (eltType.isF64())
printer = getPrintDouble(op);
else
return matchFailure();
return failure();
// Unroll vector into elementary print calls.
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:
@ -1047,8 +1047,8 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(StridedSliceOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(StridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
@ -1089,7 +1089,7 @@ public:
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, {res});
return matchSuccess();
return success();
}
};

View File

@ -198,8 +198,8 @@ struct VectorTransferRewriter : public RewritePattern {
}
/// Performs the rewrite.
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
/// Lowers TransferReadOp into a combination of:
@ -246,7 +246,7 @@ struct VectorTransferRewriter : public RewritePattern {
/// Performs the rewrite.
template <>
PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace mlir::edsc::op;
@ -282,7 +282,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
// 3. Propagate.
rewriter.replaceOp(op, vectorValue.getValue());
return matchSuccess();
return success();
}
/// Lowers TransferWriteOp into a combination of:
@ -304,7 +304,7 @@ PatternMatchResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
/// TODO(ntv): implement alternatives to clipping.
/// TODO(ntv): support non-data-parallel operations.
template <>
PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
using namespace edsc::op;
@ -340,7 +340,7 @@ PatternMatchResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
(std_dealloc(tmp)); // vexing parse...
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
} // namespace

View File

@ -727,8 +727,8 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
AffineMap map, ArrayRef<Value> mapOperands) const;
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
@ -743,10 +743,10 @@ struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
composeAffineMapAndOperands(&map, &resultOperands);
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
return this->matchFailure();
return failure();
replaceAffineOp(rewriter, affineOp, map, resultOperands);
return this->matchSuccess();
return success();
}
};
@ -1405,13 +1405,13 @@ namespace {
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
// Check that the body only contains a terminator.
if (!has_single_element(*forOp.getBody()))
return matchFailure();
return failure();
rewriter.eraseOp(forOp);
return matchSuccess();
return success();
}
};
} // end anonymous namespace

View File

@ -111,8 +111,8 @@ namespace {
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const override {
Type inputType = op.arg().getType();
Type outputType = op.getResult().getType();
@ -121,16 +121,16 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
Type expressedOutputType = inputElementType.castToExpressedType(inputType);
if (expressedOutputType != outputType) {
// Not a valid uniform cast.
return matchFailure();
return failure();
}
Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
return failure();
}
rewriter.replaceOp(op, dequantizedValue);
return matchSuccess();
return success();
}
};
@ -313,40 +313,40 @@ namespace {
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
return failure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};

View File

@ -380,8 +380,8 @@ struct GpuAllReduceConversion : public RewritePattern {
explicit GpuAllReduceConversion(MLIRContext *context)
: RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto funcOp = cast<gpu::GPUFuncOp>(op);
auto callback = [&](gpu::AllReduceOp reduceOp) {
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
@ -391,7 +391,7 @@ struct GpuAllReduceConversion : public RewritePattern {
};
while (funcOp.walk(callback).wasInterrupted()) {
}
return matchSuccess();
return success();
}
};
} // namespace

View File

@ -534,10 +534,10 @@ namespace {
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics())
return matchFailure();
return failure();
// Find the first operand that is defined by another generic op on tensors.
for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
@ -551,9 +551,9 @@ struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
if (!fusedOp)
continue;
rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};

View File

@ -531,13 +531,13 @@ public:
explicit LinalgRewritePattern(MLIRContext *context)
: RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
if (failed(Impl::doit(op, rewriter)))
return matchFailure();
return failure();
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -595,26 +595,26 @@ struct FoldAffineOp : public RewritePattern {
FoldAffineOp(MLIRContext *context)
: RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
AffineApplyOp affineApplyOp = cast<AffineApplyOp>(op);
auto map = affineApplyOp.getAffineMap();
if (map.getNumResults() != 1 || map.getNumInputs() > 1)
return matchFailure();
return failure();
AffineExpr expr = map.getResult(0);
if (map.getNumInputs() == 0) {
if (auto val = expr.dyn_cast<AffineConstantExpr>()) {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, val.getValue());
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
if (expr.dyn_cast<AffineDimExpr>() || expr.dyn_cast<AffineSymbolExpr>()) {
rewriter.replaceOp(op, op->getOperand(0));
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};
} // namespace

View File

@ -30,8 +30,8 @@ public:
struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
@ -39,14 +39,14 @@ struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
PatternMatchResult
LogicalResult
QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
PatternRewriter &rewriter) const {
Attribute value;
// Is the operand a constant?
if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
return matchFailure();
return failure();
}
// Does the qbarrier convert to a quantized type. This will not be true
@ -56,10 +56,10 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
if (!quantizedElementType) {
return matchFailure();
return failure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
return matchFailure();
return failure();
}
// Is the operand type compatible with the expressed type of the quantized
@ -67,20 +67,20 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
// from and to a quantized type).
if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg().getType())) {
return matchFailure();
return failure();
}
// Is the constant value a type expressed in a way that we support?
if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
!value.isa<SparseElementsAttr>()) {
return matchFailure();
return failure();
}
Type newConstValueType;
auto newConstValue =
quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
return matchFailure();
return failure();
}
// When creating the new const op, use a fused location that combines the
@ -92,7 +92,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
rewriter.replaceOpWithNewOp<StorageCastOp>(qbarrier, qbarrier.getType(),
newConstOp);
return matchSuccess();
return success();
}
void ConvertConstPass::runOnFunction() {

View File

@ -35,16 +35,16 @@ public:
FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
: OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
PatternMatchResult matchAndRewrite(FakeQuantOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(FakeQuantOp op,
PatternRewriter &rewriter) const override {
// TODO: If this pattern comes up more frequently, consider adding core
// support for failable rewrites.
if (failableRewrite(op, rewriter)) {
*hadFailure = true;
return Pattern::matchFailure();
return failure();
}
return Pattern::matchSuccess();
return success();
}
private:

View File

@ -88,13 +88,13 @@ struct CombineChainedAccessChain
: public OpRewritePattern<spirv::AccessChainOp> {
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
accessChainOp.base_ptr().getDefiningOp());
if (!parentAccessChainOp) {
return matchFailure();
return failure();
}
// Combine indices.
@ -105,7 +105,7 @@ struct CombineChainedAccessChain
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
accessChainOp, parentAccessChainOp.base_ptr(), indices);
return matchSuccess();
return success();
}
};
} // end anonymous namespace
@ -291,24 +291,24 @@ struct ConvertSelectionOpToSelect
: public OpRewritePattern<spirv::SelectionOp> {
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
PatternRewriter &rewriter) const override {
auto *op = selectionOp.getOperation();
auto &body = op->getRegion(0);
// Verifier allows an empty region for `spv.selection`.
if (body.empty()) {
return matchFailure();
return failure();
}
// Check that region consists of 4 blocks:
// header block, `true` block, `false` block and merge block.
if (std::distance(body.begin(), body.end()) != 4) {
return matchFailure();
return failure();
}
auto *headerBlock = selectionOp.getHeaderBlock();
if (!onlyContainsBranchConditionalOp(headerBlock)) {
return matchFailure();
return failure();
}
auto brConditionalOp =
@ -319,7 +319,7 @@ struct ConvertSelectionOpToSelect
auto *mergeBlock = selectionOp.getMergeBlock();
if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
return matchFailure();
return failure();
auto trueValue = getSrcValue(trueBlock);
auto falseValue = getSrcValue(falseBlock);
@ -335,7 +335,7 @@ struct ConvertSelectionOpToSelect
// `spv.selection` is not needed anymore.
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
private:
@ -345,9 +345,8 @@ private:
// 2. Each `spv.Store` uses the same pointer and the same memory attributes.
// 3. A control flow goes into the given merge block from the given
// conditional blocks.
PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
Block *falseBlock,
Block *mergeBlock) const;
LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
Block *mergeBlock) const;
bool onlyContainsBranchConditionalOp(Block *block) const {
return std::next(block->begin()) == block->end() &&
@ -382,12 +381,12 @@ private:
}
};
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
// Each block must consists of 2 operations.
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
(std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
return matchFailure();
return failure();
}
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
@ -399,7 +398,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
!falseBrBranchOp) {
return matchFailure();
return failure();
}
// Check that each `spv.Store` uses the same pointer, memory access
@ -407,15 +406,15 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
!isValidType(trueBrStoreOp.value().getType())) {
return matchFailure();
return failure();
}
if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
(falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
return matchFailure();
return failure();
}
return matchSuccess();
return success();
}
} // end anonymous namespace

View File

@ -177,25 +177,25 @@ class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
PatternMatchResult
LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
// TODO(antiagainst): support converting functions with one result.
if (fnType.getNumResults())
return matchFailure();
return failure();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
for (auto argType : enumerate(funcOp.getType().getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
if (!convertedType)
return matchFailure();
return failure();
signatureConverter.addInputs(argType.index(), convertedType);
}
@ -216,7 +216,7 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
newFuncOp.end());
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.eraseOp(funcOp);
return matchSuccess();
return success();
}
void mlir::populateBuiltinFuncToSPIRVPatterns(

View File

@ -27,8 +27,8 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
public:
using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
PatternRewriter &rewriter) const override {
spirv::StructType::LayoutInfo structSize = 0;
VulkanLayoutUtils::Size structAlignment = 1;
SmallVector<NamedAttribute, 4> globalVarAttrs;
@ -50,7 +50,7 @@ public:
rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
op, TypeAttr::get(decoratedType), globalVarAttrs);
return matchSuccess();
return success();
}
};
@ -59,15 +59,15 @@ class SPIRVAddressOfOpLayoutInfoDecoration
public:
using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::AddressOfOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
PatternRewriter &rewriter) const override {
auto spirvModule = op.getParentOfType<spirv::ModuleOp>();
auto varName = op.variable();
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
op, varOp.type(), rewriter.getSymbolRefAttr(varName));
return matchSuccess();
return success();
}
};
} // namespace

View File

@ -138,7 +138,7 @@ namespace {
class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
public:
using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
PatternMatchResult
LogicalResult
matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
@ -151,13 +151,13 @@ private:
};
} // namespace
PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
spirv::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
spirv::getEntryPointABIAttrName())) {
// TODO(ravishankarm) : Non-entry point functions are not handled.
return matchFailure();
return failure();
}
TypeConverter::SignatureConversion signatureConverter(
funcOp.getType().getNumInputs());
@ -171,12 +171,12 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
// to pass around scalar/vector values and return a scalar/vector. For now
// non-entry point functions are not handled in this ABI lowering and will
// produce an error.
return matchFailure();
return failure();
}
auto var =
createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
if (!var) {
return matchFailure();
return failure();
}
OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
@ -207,7 +207,7 @@ PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
});
return matchSuccess();
return success();
}
void LowerABIAttributesPass::runOnOperation() {

View File

@ -313,14 +313,14 @@ namespace {
struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
if (llvm::none_of(alloc.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return matchFailure();
return failure();
auto memrefType = alloc.getType();
@ -364,7 +364,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
alloc.getType());
rewriter.replaceOp(alloc, {resultCast});
return matchSuccess();
return success();
}
};
@ -373,13 +373,13 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
using OpRewritePattern<AllocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
if (alloc.use_empty()) {
rewriter.eraseOp(alloc);
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};
} // end anonymous namespace.
@ -461,18 +461,18 @@ namespace {
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
using OpRewritePattern<BranchOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(BranchOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(BranchOp op,
PatternRewriter &rewriter) const override {
// Check that the successor block has a single predecessor.
Block *succ = op.getDest();
Block *opParent = op.getOperation()->getBlock();
if (succ == opParent || !has_single_element(succ->getPredecessors()))
return matchFailure();
return failure();
// Merge the successor into the current block and erase the branch.
rewriter.mergeBlocks(succ, opParent, op.getOperands());
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
} // end anonymous namespace.
@ -545,18 +545,18 @@ struct SimplifyIndirectCallWithKnownCallee
: public OpRewritePattern<CallIndirectOp> {
using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
// Check that the callee is a constant callee.
SymbolRefAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
return failure();
// Replace with a direct call.
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
indirectCall.getResultTypes(),
indirectCall.getArgOperands());
return matchSuccess();
return success();
}
};
} // end anonymous namespace.
@ -733,20 +733,20 @@ namespace {
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
if (matchPattern(condbr.getCondition(), m_NonZero())) {
// True branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueOperands());
return matchSuccess();
return success();
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseOperands());
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
};
} // end anonymous namespace.
@ -958,21 +958,21 @@ namespace {
struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
using OpRewritePattern<DeallocOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
// Check that the memref operand's defining operation is an AllocOp.
Value memref = dealloc.memref();
if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
return matchFailure();
return failure();
// Check that all of the uses of the AllocOp are other DeallocOps.
for (auto *user : memref.getUsers())
if (!isa<DeallocOp>(user))
return matchFailure();
return failure();
// Erase the dealloc operation.
rewriter.eraseOp(dealloc);
return matchSuccess();
return success();
}
};
} // end anonymous namespace.
@ -2003,8 +2003,8 @@ class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
MemRefType subViewType = subViewOp.getType();
// Follow all or nothing approach for shapes for now. If all the operands
// for sizes are constants then fold it into the type of the result memref.
@ -2012,7 +2012,7 @@ public:
llvm::any_of(subViewOp.sizes(), [](Value operand) {
return !matchPattern(operand, m_ConstantIndex());
})) {
return matchFailure();
return failure();
}
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
for (auto size : llvm::enumerate(subViewOp.sizes())) {
@ -2028,7 +2028,7 @@ public:
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
return success();
}
};
@ -2037,10 +2037,10 @@ class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
if (subViewOp.getNumStrides() == 0) {
return matchFailure();
return failure();
}
// Follow all or nothing approach for strides for now. If all the operands
// for strides are constants then fold it into the strides of the result
@ -2056,7 +2056,7 @@ public:
llvm::any_of(subViewOp.strides(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
return failure();
}
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
@ -2077,7 +2077,7 @@ public:
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
return success();
}
};
@ -2086,10 +2086,10 @@ class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
if (subViewOp.getNumOffsets() == 0) {
return matchFailure();
return failure();
}
// Follow all or nothing approach for offsets for now. If all the operands
// for offsets are constants then fold it into the offset of the result
@ -2106,7 +2106,7 @@ public:
llvm::any_of(subViewOp.offsets(), [](Value stride) {
return !matchPattern(stride, m_ConstantIndex());
})) {
return matchFailure();
return failure();
}
auto staticOffset = baseOffset;
@ -2128,7 +2128,7 @@ public:
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
return matchSuccess();
return success();
}
};
@ -2347,18 +2347,18 @@ namespace {
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
using OpRewritePattern<ViewOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
// Return if none of the operands are constants.
if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
return matchPattern(operand, m_ConstantIndex());
}))
return matchFailure();
return failure();
// Get result memref type.
auto memrefType = viewOp.getType();
if (memrefType.getAffineMaps().size() > 1)
return matchFailure();
return failure();
auto map = memrefType.getAffineMaps().empty()
? AffineMap::getMultiDimIdentityMap(memrefType.getRank(),
rewriter.getContext())
@ -2368,7 +2368,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
int64_t oldOffset;
SmallVector<int64_t, 4> oldStrides;
if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
return matchFailure();
return failure();
SmallVector<Value, 4> newOperands;
@ -2444,27 +2444,27 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
// Insert a cast so we have the same type as the old memref type.
rewriter.replaceOpWithNewOp<MemRefCastOp>(viewOp, newViewOp,
viewOp.getType());
return matchSuccess();
return success();
}
};
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
using OpRewritePattern<ViewOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
Value memrefOperand = viewOp.getOperand(0);
MemRefCastOp memrefCastOp =
dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
if (!memrefCastOp)
return matchFailure();
return failure();
Value allocOperand = memrefCastOp.getOperand();
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
if (!allocOp)
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
viewOp.operands());
return matchSuccess();
return success();
}
};

View File

@ -1145,18 +1145,18 @@ class StridedSliceConstantMaskFolder final
public:
using OpRewritePattern<StridedSliceOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(StridedSliceOp stridedSliceOp,
PatternRewriter &rewriter) const override {
// Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
auto defOp = stridedSliceOp.vector().getDefiningOp();
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
if (!constantMaskOp)
return matchFailure();
return failure();
// Return if 'stridedSliceOp' has non-unit strides.
if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt() != 1;
}))
return matchFailure();
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes);
@ -1187,7 +1187,7 @@ public:
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
stridedSliceOp, stridedSliceOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
return matchSuccess();
return success();
}
};
@ -1619,14 +1619,14 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
// Return if any of 'createMaskOp' operands are not defined by a constant.
auto is_not_def_by_constant = [](Value operand) {
return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
return matchFailure();
return failure();
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
for (auto operand : createMaskOp.operands()) {
@ -1637,7 +1637,7 @@ public:
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
createMaskOp, createMaskOp.getResult().getType(),
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
return matchSuccess();
return success();
}
};

View File

@ -545,18 +545,18 @@ namespace {
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
PatternRewriter &rewriter) const override {
// TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(xferReadOp.permutation_map()))
return matchFailure();
return failure();
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
Value xferReadResult = xferReadOp.getResult();
auto extractSlicesOp =
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
if (!xferReadResult.hasOneUse() || !extractSlicesOp)
return matchFailure();
return failure();
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
auto sourceVectorType = extractSlicesOp.getSourceVectorType();
@ -593,7 +593,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
extractSlicesOp.strides());
return matchSuccess();
return success();
}
};
@ -601,23 +601,23 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
PatternRewriter &rewriter) const override {
// TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
// permutation maps. Repurpose code from MaterializeVectors transformation.
if (!isIdentitySuffix(xferWriteOp.permutation_map()))
return matchFailure();
return failure();
// Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
if (!insertSlicesOp)
return matchFailure();
return failure();
// Get TupleOp operand of 'insertSlicesOp'.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
return matchFailure();
return failure();
// Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
auto sourceTupleType = insertSlicesOp.getSourceTupleType();
@ -644,7 +644,7 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
// Erase old 'xferWriteOp'.
rewriter.eraseOp(xferWriteOp);
return matchSuccess();
return success();
}
};
@ -653,15 +653,15 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has tuple source/result type.
auto sourceTupleType =
shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
auto resultTupleType =
shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
if (!sourceTupleType || !resultTupleType)
return matchFailure();
return failure();
assert(sourceTupleType.size() == resultTupleType.size());
// Create single-vector ShapeCastOp for each source tuple element.
@ -679,7 +679,7 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
// Replace 'shapeCastOp' with tuple of 'resultElements'.
rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
resultElements);
return matchSuccess();
return success();
}
};
@ -702,21 +702,21 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
auto resultVectorType =
shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
if (!sourceVectorType || !resultVectorType)
return matchFailure();
return failure();
// Check if shape cast op source operand is also a shape cast op.
auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
shapeCastOp.source().getDefiningOp());
if (!sourceShapeCastOp)
return matchFailure();
return failure();
auto operandSourceVectorType =
sourceShapeCastOp.source().getType().cast<VectorType>();
auto operandResultVectorType =
@ -725,10 +725,10 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
// Check if shape cast operations invert each other.
if (operandSourceVectorType != resultVectorType ||
operandResultVectorType != sourceVectorType)
return matchFailure();
return failure();
rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
return matchSuccess();
return success();
}
};
@ -738,30 +738,30 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
using OpRewritePattern<vector::TupleGetOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
PatternRewriter &rewriter) const override {
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
tupleGetOp.vectors().getDefiningOp());
if (!extractSlicesOp)
return matchFailure();
return failure();
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
extractSlicesOp.vector().getDefiningOp());
if (!insertSlicesOp)
return matchFailure();
return failure();
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
insertSlicesOp.vectors().getDefiningOp());
if (!tupleOp)
return matchFailure();
return failure();
// Forward Value from 'tupleOp' at 'tupleGetOp.index'.
Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
rewriter.replaceOp(tupleGetOp, tupleValue);
return matchSuccess();
return success();
}
};
@ -778,8 +778,8 @@ class ExtractSlicesOpLowering
public:
using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ExtractSlicesOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType vectorType = op.getSourceVectorType();
@ -806,7 +806,7 @@ public:
}
rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
return matchSuccess();
return success();
}
};
@ -825,8 +825,8 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
public:
using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::InsertSlicesOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType vectorType = op.getResultVectorType();
@ -860,7 +860,7 @@ public:
}
rewriter.replaceOp(op, result);
return matchSuccess();
return success();
}
};
@ -881,8 +881,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType rhsType = op.getOperandVectorTypeRHS();
@ -907,7 +907,7 @@ public:
result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
}
rewriter.replaceOp(op, result);
return matchSuccess();
return success();
}
};
@ -934,11 +934,11 @@ public:
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformsOptions(vectorTransformsOptions) {}
PatternMatchResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
// TODO(ajcbik): implement masks
if (llvm::size(op.masks()) != 0)
return matchFailure();
return failure();
// TODO(ntv, ajcbik): implement benefits, cost models, separate this out in
// a new pattern.
@ -977,7 +977,7 @@ public:
rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul);
else
rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul);
return matchSuccess();
return success();
}
}
@ -987,7 +987,7 @@ public:
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
return matchSuccess();
return success();
}
// Collect contracting dimensions.
@ -1007,7 +1007,7 @@ public:
if (lhsContractingDimSet.count(lhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
return matchSuccess();
return success();
}
}
@ -1018,17 +1018,17 @@ public:
if (rhsContractingDimSet.count(rhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
return matchSuccess();
return success();
}
}
// Lower the first remaining reduction dimension.
if (!contractingDimMap.empty()) {
rewriter.replaceOp(op, lowerReduction(op, rewriter));
return matchSuccess();
return success();
}
return matchFailure();
return failure();
}
private:
@ -1275,12 +1275,12 @@ class ShapeCastOp2DDownCastRewritePattern
public:
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
return matchFailure();
return failure();
auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType();
@ -1295,7 +1295,7 @@ public:
/*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
}
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};
@ -1309,12 +1309,12 @@ class ShapeCastOp2DUpCastRewritePattern
public:
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
return matchFailure();
return failure();
auto loc = op.getLoc();
auto elemType = sourceVectorType.getElementType();
@ -1330,7 +1330,7 @@ public:
desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
}
rewriter.replaceOp(op, desc);
return matchSuccess();
return success();
}
};

View File

@ -44,7 +44,7 @@ void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
"rewrite functions!");
}
PatternMatchResult RewritePattern::match(Operation *op) const {
LogicalResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}

View File

@ -35,13 +35,13 @@ public:
RemoveIdentityOpRewrite(MLIRContext *context)
: RewritePattern(OpTy::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(op->getNumOperands() == 1);
assert(op->getNumResults() == 1);
rewriter.replaceOp(op, op->getOperand(0));
return matchSuccess();
return success();
}
};

View File

@ -1010,7 +1010,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
//===----------------------------------------------------------------------===//
/// Attempt to match and rewrite the IR root at the specified operation.
PatternMatchResult
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value, 4> operands;
@ -1705,7 +1705,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
: OpConversionPattern(ctx), converter(converter) {}
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult
LogicalResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FunctionType type = funcOp.getType();
@ -1714,12 +1714,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
return matchFailure();
return failure();
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
return failure();
// Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] {
@ -1727,7 +1727,7 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
convertedResults, funcOp.getContext()));
rewriter.applySignatureConversion(&funcOp.getBody(), result);
});
return matchSuccess();
return success();
}
/// The type converter to use when rewriting the signature.

View File

@ -94,32 +94,32 @@ struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
ConvertToAtomCmpExchangeWeak(MLIRContext *context);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
struct ConvertToBitReverse : public RewritePattern {
ConvertToBitReverse(MLIRContext *context);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
struct ConvertToGroupNonUniformBallot : public RewritePattern {
ConvertToGroupNonUniformBallot(MLIRContext *context);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
struct ConvertToModule : public RewritePattern {
ConvertToModule(MLIRContext *context);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
struct ConvertToSubgroupBallot : public RewritePattern {
ConvertToSubgroupBallot(MLIRContext *context);
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
@ -145,7 +145,7 @@ ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
: RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
{"spv.AtomicCompareExchangeWeak"}, 1, context) {}
PatternMatchResult
LogicalResult
ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value ptr = op->getOperand(0);
@ -159,21 +159,21 @@ ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
spirv::MemorySemantics::AcquireRelease |
spirv::MemorySemantics::AtomicCounterMemory,
spirv::MemorySemantics::Acquire, value, comparator);
return matchSuccess();
return success();
}
ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
: RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
context) {}
PatternMatchResult
LogicalResult
ConvertToBitReverse::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
op, op->getResult(0).getType(), predicate);
return matchSuccess();
return success();
}
ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
@ -181,39 +181,39 @@ ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
: RewritePattern("test.convert_to_group_non_uniform_ballot_op",
{"spv.GroupNonUniformBallot"}, 1, context) {}
PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite(
LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
return matchSuccess();
return success();
}
ConvertToModule::ConvertToModule(MLIRContext *context)
: RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
PatternMatchResult
LogicalResult
ConvertToModule::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
op, spirv::AddressingModel::PhysicalStorageBuffer64,
spirv::MemoryModel::Vulkan);
return matchSuccess();
return success();
}
ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
: RewritePattern("test.convert_to_subgroup_ballot_op",
{"spv.SubgroupBallotKHR"}, 1, context) {}
PatternMatchResult
LogicalResult
ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
op, op->getResult(0).getType(), predicate);
return matchSuccess();
return success();
}
namespace mlir {

View File

@ -283,10 +283,10 @@ struct TestRemoveOpWithInnerOps
: public OpRewritePattern<TestOpWithRegionPattern> {
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
PatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
} // end anonymous namespace

View File

@ -141,7 +141,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
TestRegionRewriteBlockMovement(MLIRContext *ctx)
: ConversionPattern("test.region", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Inline this region into the parent region.
@ -155,7 +155,7 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern {
// Drop this operation.
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
/// This pattern is a simple pattern that generates a region containing an
@ -164,8 +164,8 @@ struct TestRegionRewriteUndo : public RewritePattern {
TestRegionRewriteUndo(MLIRContext *ctx)
: RewritePattern("test.region_builder", 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
// Create the region operation with an entry block containing arguments.
OperationState newRegion(op->getLoc(), "test.region");
newRegion.addRegion();
@ -179,7 +179,7 @@ struct TestRegionRewriteUndo : public RewritePattern {
// Drop this operation.
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
};
@ -191,7 +191,7 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Region &region = op->getRegion(0);
@ -202,12 +202,12 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
if (failed(converter.convertSignatureArg(
i, entry->getArgument(i).getType(), result)))
return matchFailure();
return failure();
// Convert the region signature and just drop the operation.
rewriter.applySignatureConversion(&region, result);
rewriter.eraseOp(op);
return matchSuccess();
return success();
}
/// The type converter to use when rewriting the signature.
@ -217,35 +217,35 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
llvm::None);
return matchSuccess();
return success();
}
};
/// This pattern handles the case of a split return value.
struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
return matchFailure();
return failure();
// Check if the first operation is a cast operation, if it is we use the
// results directly.
auto *defOp = operands[0].getDefiningOp();
if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
return matchSuccess();
return success();
}
// Otherwise, fail to match.
return matchFailure();
return failure();
}
};
@ -254,52 +254,52 @@ struct TestSplitReturnType : public ConversionPattern {
struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
TestChangeProducerTypeI32ToF32(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is I32, change the type to F32.
if (!Type(*op->result_type_begin()).isSignlessInteger(32))
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
return matchSuccess();
return success();
}
};
struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
TestChangeProducerTypeF32ToF64(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// If the type is F32, change the type to F64.
if (!Type(*op->result_type_begin()).isF32())
return rewriter.notifyMatchFailure(op, "expected single f32 operand");
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
return matchSuccess();
return success();
}
};
struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern {
TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx)
: ConversionPattern("test.type_producer", 10, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Always convert to B16, even though it is not a legal type. This tests
// that values are unmapped correctly.
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type());
return matchSuccess();
return success();
}
};
struct TestUpdateConsumerType : public ConversionPattern {
TestUpdateConsumerType(MLIRContext *ctx)
: ConversionPattern("test.type_consumer", 1, ctx) {}
PatternMatchResult
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// Verify that the incoming operand has been successfully remapped to F64.
if (!operands[0].getType().isF64())
return matchFailure();
return failure();
rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]);
return matchSuccess();
return success();
}
};
@ -312,15 +312,15 @@ struct TestNonRootReplacement : public RewritePattern {
TestNonRootReplacement(MLIRContext *ctx)
: RewritePattern("test.replace_non_root", 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
auto resultType = *op->result_type_begin();
auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
rewriter.replaceOp(illegalOp, {legalOp});
rewriter.replaceOp(op, {illegalOp});
return matchSuccess();
return success();
}
};
} // namespace
@ -475,7 +475,7 @@ struct OneVResOneVOperandOp1Converter
: public OpConversionPattern<OneVResOneVOperandOp1> {
using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern;
PatternMatchResult
LogicalResult
matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto origOps = op.getOperands();
@ -490,7 +490,7 @@ struct OneVResOneVOperandOp1Converter
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
remappedOperands);
return matchSuccess();
return success();
}
};

View File

@ -215,7 +215,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
@ -300,7 +300,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
os.indent(indent) << "if (!("
<< std::string(tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf(self)))
<< ")) return matchFailure();\n";
<< ")) return failure();\n";
}
}
@ -344,7 +344,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
// should just capture a mlir::Attribute() to signal the missing state.
// That is precisely what getAttr() returns on missing attributes.
} else {
os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
os.indent(indent) << "if (!tblgen_attr) return failure();\n";
}
auto matcher = tree.getArgAsLeaf(argIndex);
@ -360,7 +360,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
os.indent(indent) << "if (!("
<< std::string(tgfmt(matcher.getConditionTemplate(),
&fmtCtx.withSelf("tblgen_attr")))
<< ")) return matchFailure();\n";
<< ")) return failure();\n";
}
// Capture the value
@ -383,7 +383,7 @@ void PatternEmitter::emitMatchLogic(DagNode tree) {
auto &entities = appliedConstraint.entities;
auto condition = constraint.getConditionTemplate();
auto cmd = "if (!({0})) return matchFailure();\n";
auto cmd = "if (!({0})) return failure();\n";
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("({0}.getType())",
@ -468,7 +468,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
// Emit matchAndRewrite() function.
os << R"(
PatternMatchResult matchAndRewrite(Operation *op0,
LogicalResult matchAndRewrite(Operation *op0,
PatternRewriter &rewriter) const override {
)";
@ -501,7 +501,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
os.indent(4) << "// Rewrite\n";
emitRewriteLogic();
os.indent(4) << "return matchSuccess();\n";
os.indent(4) << "return success();\n";
os << " };\n";
os << "};\n";
}