diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index 1387a854f605..7321938f2d4f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -599,6 +599,8 @@ def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> { %2 = spv.LogicalNot %0 : vector<4xi1> ``` }]; + + let hasCanonicalizer = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 4c9dd5b79d20..e463f8912efb 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1583,6 +1583,47 @@ static LogicalResult verify(spirv::LoadOp loadOp) { return verifyMemoryAccessAttribute(loadOp); } +//===----------------------------------------------------------------------===// +// spv.LogicalNot +//===----------------------------------------------------------------------===// + +namespace { + +/// Converts `spirv::LogicalNotOp` to the given `NewOp` using the first and the +/// second operands from the given `ParentOp`. +template +struct ConvertLogicalNotOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(spirv::LogicalNotOp logicalNotOp, + PatternRewriter &rewriter) const override { + auto parentOp = + dyn_cast_or_null(logicalNotOp.operand()->getDefiningOp()); + + if (!parentOp) { + return this->matchFailure(); + } + + rewriter.replaceOpWithNewOp( + /*valuesToRemoveIfDead=*/{parentOp.result()}, logicalNotOp, + logicalNotOp.result()->getType(), parentOp.operand1(), + parentOp.operand2()); + + return this->matchSuccess(); + } +}; +} // end anonymous namespace + +void spirv::LogicalNotOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert< + ConvertLogicalNotOp, + ConvertLogicalNotOp, + ConvertLogicalNotOp, + ConvertLogicalNotOp>( + context); +} + //===----------------------------------------------------------------------===// // spv.loop //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir index e1b9c1b8411d..85998fb03efd 100644 --- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir @@ -126,6 +126,18 @@ func @iadd_scalar(%arg: i32) -> i32 { // ----- +//===----------------------------------------------------------------------===// +// spv.IMul +//===----------------------------------------------------------------------===// + +func @imul_scalar(%arg: i32) -> i32 { + // CHECK: spv.IMul + %0 = spv.IMul %arg, %arg : i32 + return %0 : i32 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.ISub //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 87be892ce3bd..84c2544762bc 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -391,3 +391,60 @@ func @cannot_canonicalize_selection_op_4(%cond: i1) -> () { } spv.Return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.LogicalNot +//===----------------------------------------------------------------------===// + +func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.IEqual %arg0, %arg1 : vector<3xi64> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.INotEqual %arg0, %arg1 : vector<3xi64> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_parent_multi_use(%arg0: vector<3xi64>, %arg1: vector<3xi64>, %arg2: !spv.ptr, Uniform>) -> vector<3xi1> { + // CHECK: %[[RESULT_0:.*]] = spv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: %[[RESULT_1:.*]] = spv.IEqual {{%.*}}, {{%.*}} : vector<3xi64> + // CHECK-NEXT: spv.Store "Uniform" {{%.*}}, %[[RESULT_0]] + // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]] + %0 = spv.INotEqual %arg0, %arg1 : vector<3xi64> + %1 = spv.LogicalNot %0 : vector<3xi1> + spv.Store "Uniform" %arg2, %0 : vector<3xi1> + spv.ReturnValue %1 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_logical_not_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.LogicalNotEqual {{%.*}}, {{%.*}} : vector<3xi1> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.LogicalEqual %arg0, %arg1 : vector<3xi1> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} + +// ----- + +func @convert_logical_not_to_logical_equal(%arg0: vector<3xi1>, %arg1: vector<3xi1>) -> vector<3xi1> { + // CHECK: %[[RESULT:.*]] = spv.LogicalEqual {{%.*}}, {{%.*}} : vector<3xi1> + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] : vector<3xi1> + %2 = spv.LogicalNotEqual %arg0, %arg1 : vector<3xi1> + %3 = spv.LogicalNot %2 : vector<3xi1> + spv.ReturnValue %3 : vector<3xi1> +} diff --git a/mlir/test/Dialect/SPIRV/logical-ops.mlir b/mlir/test/Dialect/SPIRV/logical-ops.mlir index 436e32d8a4a7..d102ae98d3ae 100644 --- a/mlir/test/Dialect/SPIRV/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/logical-ops.mlir @@ -32,18 +32,6 @@ func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi // ----- -//===----------------------------------------------------------------------===// -// spv.IMul -//===----------------------------------------------------------------------===// - -func @imul_scalar(%arg: i32) -> i32 { - // CHECK: spv.IMul - %0 = spv.IMul %arg, %arg : i32 - return %0 : i32 -} - -// ----- - //===----------------------------------------------------------------------===// // spv.SGreaterThan //===----------------------------------------------------------------------===//