Add xor bitwise operation to StandardOps.

This adds parsing, printing and some folding/canonicalization.

Also extends rewriting of subi %0, %0 to handle vectors and tensors.

--

PiperOrigin-RevId: 242448164
Stephan Herhut 2019-04-08 05:53:59 -07:00 committed by Mehdi Amini
parent ca89e7167d
commit af016ba7a4
5 changed files with 132 additions and 1 deletion
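
As a quick illustration of the folding/canonicalization described above, here is a minimal sketch of the rewrites this change enables. It is not part of the original commit; the function and value names are made up.

```mlir
// Hypothetical input IR.
func @xor_folds(%x: i32, %v: vector<4xi32>) -> (i32, i32, vector<4xi32>) {
  %zero = constant 0 : i32
  // xor(x, 0) -> x: folded away by XOrOp::fold().
  %a = xor %x, %zero : i32
  // xor(x, x) -> 0: rewritten to a zero constant by SimplifyXXOrX.
  %b = xor %x, %x : i32
  // subi(x, x) now also canonicalizes for vectors and tensors,
  // producing %cst = constant splat<vector<4xi32>, 0>.
  %c = subi %v, %v : vector<4xi32>
  return %a, %b, %c : i32, i32, vector<4xi32>
}
```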


@@ -2204,6 +2204,35 @@ same element type, and the source and destination types may not be the same.
They must either have the same rank, or one may be an unknown rank. The
operation is invalid if converting to a mismatching constant dimension.

#### 'xor' operation

Bitwise integer xor.

Syntax:

``` {.ebnf}
operation ::= ssa-id `=` `xor` ssa-use, ssa-use `:` type
```

Examples:

```mlir {.mlir}
// Scalar integer bitwise xor.
%a = xor %b, %c : i64

// SIMD vector element-wise bitwise integer xor.
%f = xor %g, %h : vector<4xi32>

// Tensor element-wise bitwise integer xor.
%x = xor %y, %z : tensor<4x?xi8>
```
The `xor` operation takes two operands and returns one result, all of which are
required to have the same type. This type may be an integer scalar type, a
vector whose element type is integer, or a tensor of integers. It has no
standard attributes.
## Dialects
MLIR supports multiple dialects containing a set of operations and types defined
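
A usage note (not part of the original documentation text): because `xor` with an all-ones constant flips every bit, it can express bitwise negation. The names below are illustrative.

```mlir
// Bitwise NOT of %x, written as xor with -1 (all bits set).
%all_ones = constant -1 : i32
%not_x = xor %x, %all_ones : i32
```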


@@ -142,4 +142,11 @@ def SubIOp : IntArithmeticOp<"std.subi"> {
  let hasCanonicalizer = 0b1;
}

def XOrOp : IntArithmeticOp<"std.xor", [Commutative]> {
  let summary = "integer binary xor";
  let hasConstantFolder = 0b1;
  let hasCanonicalizer = 0b1;
  let hasFolder = 1;
}

#endif // STANDARD_OPS


@@ -1991,7 +1991,8 @@ struct SimplifyXMinusX : public RewritePattern {
    if (subi.getOperand(0) != subi.getOperand(1))
      return matchFailure();

-    rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi.getType());
+    rewriter.replaceOpWithNewOp<ConstantOp>(
+        op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
    return matchSuccess();
  }
};
@@ -2044,6 +2045,48 @@ Value *OrOp::fold() {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// XOrOp
//===----------------------------------------------------------------------===//

Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
                              MLIRContext *context) {
  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a ^ b; });
}

Value *XOrOp::fold() {
  /// xor(x, 0) -> x
  if (matchPattern(rhs(), m_Zero()))
    return lhs();

  return nullptr;
}

namespace {
/// xor(x, x) -> 0
///
struct SimplifyXXOrX : public RewritePattern {
  SimplifyXXOrX(MLIRContext *context)
      : RewritePattern(XOrOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto xorOp = op->cast<XOrOp>();
    if (xorOp.lhs() != xorOp.rhs())
      return matchFailure();

    rewriter.replaceOpWithNewOp<ConstantOp>(
        op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
    return matchSuccess();
  }
};
} // end anonymous namespace.

void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
  results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
}

//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
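
To make the constant folder above concrete, here is a small worked example with illustrative values: `constFoldBinaryOp` applies the `a ^ b` lambda to the operand attributes, so an `xor` of two integer constants folds to their bitwise xor.

```mlir
// Before folding:
%c5 = constant 5 : i32
%c3 = constant 3 : i32
%r = xor %c5, %c3 : i32   // 0b101 ^ 0b011

// After folding, uses of %r resolve to:
%c6 = constant 6 : i32
```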


@@ -241,6 +241,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
  // CHECK: %{{[0-9]+}} = or %cst_4, %cst_4 : tensor<42xi32>
  %59 = or %tci32, %tci32 : tensor<42 x i32>

  // CHECK: %{{[0-9]+}} = xor %arg2, %arg2 : i32
  %60 = "std.xor"(%i, %i) : (i32,i32) -> i32

  // CHECK: %{{[0-9]+}} = xor %arg2, %arg2 : i32
  %61 = xor %i, %i : i32

  // CHECK: %{{[0-9]+}} = xor %cst_5, %cst_5 : vector<42xi32>
  %62 = std.xor %vci32, %vci32 : vector<42 x i32>

  // CHECK: %{{[0-9]+}} = xor %cst_4, %cst_4 : tensor<42xi32>
  %63 = xor %tci32, %tci32 : tensor<42 x i32>

  return
}


@@ -8,6 +8,22 @@ func @test_subi_zero(%arg0: i32) -> i32 {
  return %y: i32
}

// CHECK-LABEL: func @test_subi_zero_vector
func @test_subi_zero_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: %cst = constant splat<vector<4xi32>, 0>
  %y = subi %arg0, %arg0 : vector<4xi32>
  // CHECK-NEXT: return %cst
  return %y: vector<4xi32>
}

// CHECK-LABEL: func @test_subi_zero_tensor
func @test_subi_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: %cst = constant splat<tensor<4x5xi32>, 0>
  %y = subi %arg0, %arg0 : tensor<4x5xi32>
  // CHECK-NEXT: return %cst
  return %y: tensor<4x5xi32>
}

// CHECK-LABEL: func @dim
func @dim(%arg0: tensor<8x4xf32>) -> index {
@@ -214,6 +230,30 @@ func @or_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  return %1 : tensor<4x5xi32>
}

//CHECK-LABEL: func @xor_self
func @xor_self(%arg0: i32) -> i32 {
  //CHECK-NEXT: %c0_i32 = constant 0
  %1 = xor %arg0, %arg0 : i32
  //CHECK-NEXT: return %c0_i32
  return %1 : i32
}

//CHECK-LABEL: func @xor_self_vector
func @xor_self_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
  //CHECK-NEXT: %cst = constant splat<vector<4xi32>, 0>
  %1 = xor %arg0, %arg0 : vector<4xi32>
  //CHECK-NEXT: return %cst
  return %1 : vector<4xi32>
}

//CHECK-LABEL: func @xor_self_tensor
func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
  //CHECK-NEXT: %cst = constant splat<tensor<4x5xi32>, 0>
  %1 = xor %arg0, %arg0 : tensor<4x5xi32>
  //CHECK-NEXT: return %cst
  return %1 : tensor<4x5xi32>
}

// CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
  %1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>