forked from OSchip/llvm-project
Add xor bitwise operation to StandardOps.
This adds parsing, printing and some folding/canonicalization. Also extends rewriting of subi %0, %0 to handle vectors and tensors. -- PiperOrigin-RevId: 242448164
This commit is contained in:
parent
ca89e7167d
commit
af016ba7a4
|
@ -2204,6 +2204,35 @@ same element type, and the source and destination types may not be the same.
|
|||
They must either have the same rank, or one may be an unknown rank. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
|
||||
|
||||
#### 'xor' operation
|
||||
|
||||
Bitwise integer xor.
|
||||
|
||||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
operation ::= ssa-id `=` `xor` ssa-use, ssa-use `:` type
|
||||
```
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir {.mlir}
|
||||
// Scalar integer bitwise xor.
|
||||
%a = xor %b, %c : i64
|
||||
|
||||
// SIMD vector element-wise bitwise integer xor.
|
||||
%f = xor %g, %h : vector<4xi32>
|
||||
|
||||
// Tensor element-wise bitwise integer xor.
|
||||
%x = xor %y, %z : tensor<4x?xi8>
|
||||
```
|
||||
|
||||
The `xor` operation takes two operands and returns one result, each of these is
|
||||
required to be the same type. This type may be an integer scalar type, a vector
|
||||
whose element type is integer, or a tensor of integers. It has no standard
|
||||
attributes.
|
||||
|
||||
## Dialects
|
||||
|
||||
MLIR supports multiple dialects containing a set of operations and types defined
|
||||
|
|
|
@ -142,4 +142,11 @@ def SubIOp : IntArithmeticOp<"std.subi"> {
|
|||
let hasCanonicalizer = 0b1;
|
||||
}
|
||||
|
||||
def XOrOp : IntArithmeticOp<"std.xor", [Commutative]> {
|
||||
let summary = "integer binary xor";
|
||||
let hasConstantFolder = 0b1;
|
||||
let hasCanonicalizer = 0b1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
#endif // STANDARD_OPS
|
||||
|
|
|
@ -1991,7 +1991,8 @@ struct SimplifyXMinusX : public RewritePattern {
|
|||
if (subi.getOperand(0) != subi.getOperand(1))
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantIntOp>(op, 0, subi.getType());
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||||
op, subi.getType(), rewriter.getZeroAttr(subi.getType()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -2044,6 +2045,48 @@ Value *OrOp::fold() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XOrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute XOrOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) {
|
||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||
[](APInt a, APInt b) { return a ^ b; });
|
||||
}
|
||||
|
||||
Value *XOrOp::fold() {
|
||||
/// xor(x, 0) -> x
|
||||
if (matchPattern(rhs(), m_Zero()))
|
||||
return lhs();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// xor(x,x) -> 0
|
||||
///
|
||||
struct SimplifyXXOrX : public RewritePattern {
|
||||
SimplifyXXOrX(MLIRContext *context)
|
||||
: RewritePattern(XOrOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto xorOp = op->cast<XOrOp>();
|
||||
if (xorOp.lhs() != xorOp.rhs())
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(
|
||||
op, xorOp.getType(), rewriter.getZeroAttr(xorOp.getType()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(llvm::make_unique<SimplifyXXOrX>(context));
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -241,6 +241,18 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index) {
|
|||
// CHECK: %{{[0-9]+}} = or %cst_4, %cst_4 : tensor<42xi32>
|
||||
%59 = or %tci32, %tci32 : tensor<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = xor %arg2, %arg2 : i32
|
||||
%60 = "std.xor"(%i, %i) : (i32,i32) -> i32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = xor %arg2, %arg2 : i32
|
||||
%61 = xor %i, %i : i32
|
||||
|
||||
// CHECK: %{{[0-9]+}} = xor %cst_5, %cst_5 : vector<42xi32>
|
||||
%62 = std.xor %vci32, %vci32 : vector<42 x i32>
|
||||
|
||||
// CHECK: %{{[0-9]+}} = xor %cst_4, %cst_4 : tensor<42xi32>
|
||||
%63 = xor %tci32, %tci32 : tensor<42 x i32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,6 +8,22 @@ func @test_subi_zero(%arg0: i32) -> i32 {
|
|||
return %y: i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_subi_zero_vector
|
||||
func @test_subi_zero_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
|
||||
//CHECK-NEXT: %cst = constant splat<vector<4xi32>, 0>
|
||||
%y = subi %arg0, %arg0 : vector<4xi32>
|
||||
// CHECK-NEXT: return %cst
|
||||
return %y: vector<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_subi_zero_tensor
|
||||
func @test_subi_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
||||
//CHECK-NEXT: %cst = constant splat<tensor<4x5xi32>, 0>
|
||||
%y = subi %arg0, %arg0 : tensor<4x5xi32>
|
||||
// CHECK-NEXT: return %cst
|
||||
return %y: tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dim
|
||||
func @dim(%arg0: tensor<8x4xf32>) -> index {
|
||||
|
||||
|
@ -214,6 +230,30 @@ func @or_zero_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
|||
return %1 : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: func @xor_self
|
||||
func @xor_self(%arg0: i32) -> i32 {
|
||||
//CHECK-NEXT: %c0_i32 = constant 0
|
||||
%1 = xor %arg0, %arg0 : i32
|
||||
//CHECK-NEXT: return %c0_i32
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
//CHECK-LABEL: func @xor_self_vector
|
||||
func @xor_self_vector(%arg0: vector<4xi32>) -> vector<4xi32> {
|
||||
//CHECK-NEXT: %cst = constant splat<vector<4xi32>, 0>
|
||||
%1 = xor %arg0, %arg0 : vector<4xi32>
|
||||
//CHECK-NEXT: return %cst
|
||||
return %1 : vector<4xi32>
|
||||
}
|
||||
|
||||
//CHECK-LABEL: func @xor_self_tensor
|
||||
func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
||||
//CHECK-NEXT: %cst = constant splat<tensor<4x5xi32>, 0>
|
||||
%1 = xor %arg0, %arg0 : tensor<4x5xi32>
|
||||
//CHECK-NEXT: return %cst
|
||||
return %1 : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_cast_folding
|
||||
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
|
||||
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
|
||||
|
|
Loading…
Reference in New Issue