[spirv] Add a canonicalizer for `spirv::LogicalNotOp`.

Add a canonicalizer for `spirv::LogicalNotOp`.
Converts:
* spv.LogicalNot(spv.IEqual(...)) -> spv.INotEqual(...)
* spv.LogicalNot(spv.INotEqual(...)) -> spv.IEqual(...)
* spv.LogicalNot(spv.LogicalEqual(...)) -> spv.LogicalNotEqual(...)
* spv.LogicalNot(spv.LogicalNotEqual(...)) -> spv.LogicalEqual(...)

Also moved the test for spv.IMul to arithemtic tests.

Closes tensorflow/mlir#256

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/256 from denis0x0D:sandbox/canon_logical_not 76ab5787b2c777f948c8978db061d99e76453d44
PiperOrigin-RevId: 282012356
This commit is contained in:
Denis Khalikov 2019-11-22 11:49:22 -08:00 committed by A. Unique TensorFlower
parent 6db8530c26
commit a5cda4763f
5 changed files with 112 additions and 12 deletions

View File

@ -599,6 +599,8 @@ def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> {
%2 = spv.LogicalNot %0 : vector<4xi1>
```
}];
let hasCanonicalizer = 1;
}
// -----

View File

@ -1583,6 +1583,47 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
return verifyMemoryAccessAttribute(loadOp);
}
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
namespace {
/// Converts `spirv::LogicalNotOp` to the given `NewOp` using the first and the
/// second operands from the given `ParentOp`.
template <typename NewOp, typename ParentOp>
struct ConvertLogicalNotOp : public OpRewritePattern<spirv::LogicalNotOp> {
using OpRewritePattern<spirv::LogicalNotOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::LogicalNotOp logicalNotOp,
PatternRewriter &rewriter) const override {
auto parentOp =
dyn_cast_or_null<ParentOp>(logicalNotOp.operand()->getDefiningOp());
if (!parentOp) {
return this->matchFailure();
}
rewriter.replaceOpWithNewOp<NewOp>(
/*valuesToRemoveIfDead=*/{parentOp.result()}, logicalNotOp,
logicalNotOp.result()->getType(), parentOp.operand1(),
parentOp.operand2());
return this->matchSuccess();
}
};
} // end anonymous namespace
void spirv::LogicalNotOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<
ConvertLogicalNotOp<spirv::INotEqualOp, spirv::IEqualOp>,
ConvertLogicalNotOp<spirv::IEqualOp, spirv::INotEqualOp>,
ConvertLogicalNotOp<spirv::LogicalNotEqualOp, spirv::LogicalEqualOp>,
ConvertLogicalNotOp<spirv::LogicalEqualOp, spirv::LogicalNotEqualOp>>(
context);
}
//===----------------------------------------------------------------------===//
// spv.loop
//===----------------------------------------------------------------------===//

View File

@ -126,6 +126,18 @@ func @iadd_scalar(%arg: i32) -> i32 {
// -----
//===----------------------------------------------------------------------===//
// spv.IMul
//===----------------------------------------------------------------------===//
func @imul_scalar(%arg: i32) -> i32 {
// CHECK: spv.IMul
%0 = spv.IMul %arg, %arg : i32
return %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.ISub
//===----------------------------------------------------------------------===//

View File

@ -391,3 +391,60 @@ func @cannot_canonicalize_selection_op_4(%cond: i1) -> () {
}
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
%2 = spv.IEqual %arg0, %arg1 : vector<3xi64>
%3 = spv.LogicalNot %2 : vector<3xi1>
spv.ReturnValue %3 : vector<3xi1>
}
// -----
func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
%2 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
%3 = spv.LogicalNot %2 : vector<3xi1>
spv.ReturnValue %3 : vector<3xi1>
}
// -----
func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr<vector<3xi1>, Uniform>) -> vector<3xi1> {
// CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]]
// CHECK-NEXT: spv.ReturnValue %[[RESULT_1]]
%0 = spv.INotEqual %arg0, %arg1 : vector<3xi64>
%1 = spv.LogicalNot %0 : vector<3xi1>
spv.Store "Uniform" %arg2, %0 : vector<3xi1>
spv.ReturnValue %1 : vector<3xi1>
}
// -----
func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1>
// CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
%2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1>
%3 = spv.LogicalNot %2 : vector<3xi1>
spv.ReturnValue %3 : vector<3xi1>
}
// -----
func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1>
// CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1>
%2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1>
%3 = spv.LogicalNot %2 : vector<3xi1>
spv.ReturnValue %3 : vector<3xi1>
}

View File

@ -32,18 +32,6 @@ func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi
// -----
//===----------------------------------------------------------------------===//
// spv.IMul
//===----------------------------------------------------------------------===//
func @imul_scalar(%arg: i32) -> i32 {
// CHECK: spv.IMul
%0 = spv.IMul %arg, %arg : i32
return %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.SGreaterThan
//===----------------------------------------------------------------------===//