[spirv] Add bit ops

This CL added op definitions for a few bit operations:

* OpShiftLeftLogical
* OpShiftRightArithmetic
* OpShiftRightLogical
* OpBitCount
* OpBitReverse
* OpNot

Also moved the definition of spv.BitwiseAnd to follow the
lexicographical order.

Closes tensorflow/mlir#215

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/215 from denis0x0D:sandbox/bit_ops d9b0852b689ac6c4879a9740b1740a2357f44d24
PiperOrigin-RevId: 279350470
This commit is contained in:
Denis Khalikov 2019-11-08 11:05:32 -08:00 committed by A. Unique TensorFlower
parent 24f306a22b
commit 4697d657b7
5 changed files with 411 additions and 13 deletions

View File

@ -166,9 +166,15 @@ def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188
def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
def SPV_OC_OpShiftRightLogical : I32EnumAttrCase<"OpShiftRightLogical", 194>;
def SPV_OC_OpShiftRightArithmetic : I32EnumAttrCase<"OpShiftRightArithmetic", 195>;
def SPV_OC_OpShiftLeftLogical : I32EnumAttrCase<"OpShiftLeftLogical", 196>;
def SPV_OC_OpBitwiseOr : I32EnumAttrCase<"OpBitwiseOr", 197>;
def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>;
def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>;
def SPV_OC_OpNot : I32EnumAttrCase<"OpNot", 200>;
def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>;
def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>;
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>;
@ -213,7 +219,9 @@ def SPV_OpcodeAttr :
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitReverse, SPV_OC_OpBitCount,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi,
SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,

View File

@ -33,6 +33,124 @@ class SPV_BitBinaryOp<string mnemonic, list<OpTrait> traits = []> :
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;
class SPV_BitUnaryOp<string mnemonic, list<OpTrait> traits = []> :
SPV_UnaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;
class SPV_ShiftOp<string mnemonic, list<OpTrait> traits = []> :
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultShape])> {
let parser = [{ return ::parseShiftOp(parser, result); }];
let printer = [{ ::printShiftOp(this->getOperation(), p); }];
let verifier = [{ return ::verifyShiftOp(this->getOperation()); }];
}
// -----
def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> {
let summary = "Count the number of set bits in an object.";
let description = [{
Results are computed per component.
Result Type must be a scalar or vector of integer type. The components
must be wide enough to hold the unsigned Width of Base as an unsigned
value. That is, no sign bit is needed or counted when checking for a
wide enough result width.
Base must be a scalar or vector of integer type. It must have the same
number of components as Result Type.
The result is the unsigned value that is the number of bits in Base that
are 1.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitcount-op ::= ssa-id `=` `spv.BitCount` ssa-use
`:` integer-scalar-vector-type
```
For example:
```
%2 = spv.BitCount %0: i32
%3 = spv.BitCount %1: vector<4xi32>
```
}];
}
// -----
def SPV_BitReverseOp : SPV_BitUnaryOp<"BitReverse", []> {
let summary = "Reverse the bits in an object.";
let description = [{
Results are computed per component.
Result Type must be a scalar or vector of integer type.
The type of Base must be the same as Result Type.
The bit-number n of the result will be taken from bit-number Width - 1 -
n of Base, where Width is the OpTypeInt operand of the Result Type.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitreverse-op ::= ssa-id `=` `spv.BitReverse` ssa-use
`:` integer-scalar-vector-type
```
For example:
```
%2 = spv.BitReverse %0 : i32
%3 = spv.BitReverse %1 : vector<4xi32>
```
}];
}
// -----
def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
let summary = [{
Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either
Operand 1 or Operand 2 are 0.
}];
let description = [{
Results are computed per component, and within each component, per bit.
Result Type must be a scalar or vector of integer type. The type of
Operand 1 and Operand 2 must be a scalar or vector of integer type.
They must have the same number of components as Result Type. They must
have the same component width as Result Type.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitwise-and-op ::= ssa-id `=` `spv.BitwiseAnd` ssa-use, ssa-use
`:` integer-scalar-vector-type
```
For example:
```
%2 = spv.BitwiseAnd %0, %1 : i32
%2 = spv.BitwiseAnd %0, %1 : vector<4xi32>
```
}];
}
// -----
def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> {
@ -103,34 +221,158 @@ def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> {
// -----
def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> {
let summary = [{
Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either
Operand 1 or Operand 2 are 0.
Shift the bits in Base left by the number of bits specified in Shift.
The least-significant bits will be zero filled.
}];
let description = [{
Results are computed per component, and within each component, per bit.
Result Type must be a scalar or vector of integer type.
Result Type must be a scalar or vector of integer type. The type of
Operand 1 and Operand 2 must be a scalar or vector of integer type.
They must have the same number of components as Result Type. They must
have the same component width as Result Type.
The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.
Shift is treated as unsigned. The result is undefined if Shift is
greater than or equal to the bit width of the components of Base.
The number of components and bit width of Result Type must match those
Base type. All types must be integer types.
Results are computed per component.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitwise-and-op ::= ssa-id `=` `spv.BitwiseAnd` ssa-use, ssa-use
`:` integer-scalar-vector-type
shift-left-logical-op ::= ssa-id `=` `spv.ShiftLeftLogical`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```
For example:
```
%2 = spv.BitwiseAnd %0, %1 : i32
%2 = spv.BitwiseAnd %0, %1 : vector<4xi32>
%2 = spv.ShiftLeftLogical %0, %1 : i32, i16
%5 = spv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}
// -----
def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> {
let summary = [{
Shift the bits in Base right by the number of bits specified in Shift.
The most-significant bits will be filled with the sign bit from Base.
}];
let description = [{
Result Type must be a scalar or vector of integer type.
The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.
Shift is treated as unsigned. The result is undefined if Shift is
greater than or equal to the bit width of the components of Base.
Results are computed per component.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
shift-right-arithmetic-op ::= ssa-id `=` `spv.ShiftRightArithmetic`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```
For example:
```
%2 = spv.ShiftRightArithmetic %0, %1 : i32, i16
%5 = spv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}
// -----
def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> {
let summary = [{
Shift the bits in Base right by the number of bits specified in Shift.
The most-significant bits will be zero filled.
}];
let description = [{
Result Type must be a scalar or vector of integer type.
The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.
Shift is consumed as an unsigned integer. The result is undefined if
Shift is greater than or equal to the bit width of the components of
Base.
Results are computed per component.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
shift-right-logical-op ::= ssa-id `=` `spv.ShiftRightLogical`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```
For example:
```
%2 = spv.ShiftRightLogical %0, %1 : i32, i16
%5 = spv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}
// -----
def SPV_NotOp : SPV_BitUnaryOp<"Not", []> {
let summary = "Complement the bits of Operand.";
let description = [{
Results are computed per component, and within each component, per bit.
Result Type must be a scalar or vector of integer type.
Operands type must be a scalar or vector of integer type. It must
have the same number of components as Result Type. The component width
must equal the component width in Result Type.
### Custom assembly form
``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
not-op ::= ssa-id `=` `spv.BitNot` ssa-use `:` integer-scalar-vector-type
```
For example:
```
%2 = spv.Not %0 : i32
%3 = spv.Not %1 : vector<4xi32>
```
}];
}

View File

@ -449,6 +449,40 @@ static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << " : " << logicalOp->getOperand(0)->getType();
}
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
Type baseType;
Type shiftType;
auto loc = parser.getCurrentLocation();
if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
parser.parseType(baseType) || parser.parseComma() ||
parser.parseType(shiftType) ||
parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
state.operands)) {
return failure();
}
state.addTypes(baseType);
return success();
}
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
Value *base = op->getOperand(0);
Value *shift = op->getOperand(1);
printer << op->getName() << ' ' << *base << ", " << *shift << " : "
<< base->getType() << ", " << shift->getType();
}
static LogicalResult verifyShiftOp(Operation *op) {
if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
return op->emitError("expected the same type for the first operand and "
"result, but provided ")
<< op->getOperand(0)->getType() << " and "
<< op->getResult(0)->getType();
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,34 @@
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
spv.module "Logical" "GLSL450" {
func @bitcount(%arg: i32) -> i32 {
// CHECK: spv.BitCount {{%.*}} : i32
%0 = spv.BitCount %arg : i32
spv.ReturnValue %0 : i32
}
func @bitreverse(%arg: i32) -> i32 {
// CHECK: spv.BitReverse {{%.*}} : i32
%0 = spv.BitReverse %arg : i32
spv.ReturnValue %0 : i32
}
func @not(%arg: i32) -> i32 {
// CHECK: spv.Not {{%.*}} : i32
%0 = spv.Not %arg : i32
spv.ReturnValue %0 : i32
}
func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
// CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16
%0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16
spv.ReturnValue %0 : i32
}
func @shift_right_aritmethic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> {
// CHECK: {{%.*}} = spv.ShiftRightArithmetic {{%.*}}, {{%.*}} : vector<4xi32>, vector<4xi8>
%0 = spv.ShiftRightArithmetic %arg0, %arg1: vector<4xi32>, vector<4xi8>
spv.ReturnValue %0 : vector<4xi32>
}
func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> {
// CHECK: {{%.*}} = spv.ShiftRightLogical {{%.*}}, {{%.*}} : vector<2xi32>, vector<2xi8>
%0 = spv.ShiftRightLogical %arg0, %arg1: vector<2xi32>, vector<2xi8>
spv.ReturnValue %0 : vector<2xi32>
}
}

View File

@ -202,6 +202,30 @@ func @cast3(%arg0 : i64) {
// -----
//===----------------------------------------------------------------------===//
// spv.BitCount
//===----------------------------------------------------------------------===//
func @bitcount(%arg: i32) -> i32 {
// CHECK: spv.BitCount {{%.*}} : i32
%0 = spv.BitCount %arg : i32
spv.ReturnValue %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.BitReverse
//===----------------------------------------------------------------------===//
func @bitreverse(%arg: i32) -> i32 {
// CHECK: spv.BitReverse {{%.*}} : i32
%0 = spv.BitReverse %arg : i32
spv.ReturnValue %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@ -857,6 +881,18 @@ func @memory_barrier_2() -> () {
// -----
//===----------------------------------------------------------------------===//
// spv.Not
//===----------------------------------------------------------------------===//
func @not(%arg: i32) -> i32 {
// CHECK: spv.Not {{%.*}} : i32
%0 = spv.Not %arg : i32
spv.ReturnValue %0 : i32
}
// -----
//===----------------------------------------------------------------------===//
// spv.SelectOp
//===----------------------------------------------------------------------===//
@ -961,6 +997,50 @@ func @select_op(%arg1: vector<4xi1>) -> () {
// -----
//===----------------------------------------------------------------------===//
// spv.ShiftLeftLogical
//===----------------------------------------------------------------------===//
func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
// CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16
%0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16
spv.ReturnValue %0 : i32
}
// -----
func @shift_left_logical_invalid_result_type(%arg0: i32, %arg1 : i16) -> i16 {
// expected-error @+1 {{expected the same type for the first operand and result, but provided 'i32' and 'i16'}}
%0 = "spv.ShiftLeftLogical" (%arg0, %arg1) : (i32, i16) -> (i16)
spv.ReturnValue %0 : i16
}
// -----
//===----------------------------------------------------------------------===//
// spv.ShiftRightArithmetic
//===----------------------------------------------------------------------===//
func @shift_right_aritmethic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> {
// CHECK: {{%.*}} = spv.ShiftRightArithmetic {{%.*}}, {{%.*}} : vector<4xi32>, vector<4xi8>
%0 = spv.ShiftRightArithmetic %arg0, %arg1: vector<4xi32>, vector<4xi8>
spv.ReturnValue %0 : vector<4xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.ShiftRightLogical
//===----------------------------------------------------------------------===//
func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> {
// CHECK: {{%.*}} = spv.ShiftRightLogical {{%.*}}, {{%.*}} : vector<2xi32>, vector<2xi8>
%0 = spv.ShiftRightLogical %arg0, %arg1: vector<2xi32>, vector<2xi8>
spv.ReturnValue %0 : vector<2xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//