Add spv.Bitcast operation to SPIR-V dialect

Support the OpBitcast instruction of SPIR-V using the spv.Bitcast
operation. The semantics implemented in the dialect differ from the
SPIR-V spec in that the dialect does not allow conversion to/from
pointer types from/to non-pointer types.

PiperOrigin-RevId: 271255957
This commit is contained in:
Mahesh Ravishankar 2019-09-25 19:01:18 -07:00 committed by A. Unique TensorFlower
parent 47a7021cc3
commit 6f0e65441c
5 changed files with 227 additions and 6 deletions

View File

@ -116,6 +116,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>;
def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
@ -179,11 +180,11 @@ def SPV_OpcodeAttr :
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFNegate,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpBitcast,
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
@ -241,6 +242,7 @@ class SPV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;

View File

@ -126,6 +126,58 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
// -----
def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
let summary = "Bit pattern-preserving type conversion.";
let description = [{
Result Type must be an OpTypePointer, or a scalar or vector of
numerical-type.
Operand must have a type of OpTypePointer, or a scalar or vector of
numerical-type. It must be a different type than Result Type.
If either Result Type or Operand is a pointer, the other must be a
pointer (diverges from the SPIR-V spec).
If Result Type has a different number of components than Operand, the
total number of bits in Result Type must equal the total number of bits
in Operand. Let L be the type, either Result Type or Operands type,
that has the larger number of components. Let S be the other type, with
the smaller number of components. The number of components in L must be
an integer multiple of the number of components in S. The first
component (that is, the only or lowest-numbered component) of S maps to
the first components of L, and so on, up to the last component of S
mapping to the last components of L. Within this mapping, any single
component of S (mapping to multiple components of L) maps its lower-
ordered bits to the lower-numbered components of L.
### Custom assembly form
``` {.ebnf}
bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use
`from` operand-type `to` result-type
```
For example:
```
%1 = spv.Bitcast %0 from f32 to i32
%1 = spv.Bitcast %0 from vector<2xf32> to i64
%1 = spv.Bitcast %0 from !spv.ptr<f32 to Function> to !spv.ptr<i32, Function>
```
}];
let arguments = (ins
SPV_ScalarOrVectorOrPtr:$operand
);
let results = (outs
SPV_ScalarOrVectorOrPtr:$result
);
}
// -----
def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
let summary = "Extract a part of a composite object.";
@ -339,7 +391,10 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
Ensures that memory accesses issued before this instruction will be
observed before memory accesses issued after this instruction. This
control is ensured only for memory accesses issued by this invocation
and observed by another invocation executing within Memory scope.
and observed by another invocation executing within Memory scope. If the
Vulkan memory model is declared, this ordering only applies to memory
accesses that use the NonPrivatePointer memory operand or
NonPrivateTexel image operand.
Semantics declares what kind of memory is being controlled and what kind
of control to apply.

View File

@ -316,6 +316,24 @@ static Attribute extractCompositeElement(Attribute composite,
return {};
}
// Get bit width of types.
static unsigned getBitWidth(Type type) {
if (type.isa<spirv::PointerType>()) {
// Just return 64 bits for pointer types for now.
// TODO: Make sure not caller relies on the actual pointer width value.
return 64;
}
if (type.isIntOrFloat()) {
return type.getIntOrFloatBitWidth();
}
if (auto vectorType = type.dyn_cast<VectorType>()) {
assert(vectorType.getElementType().isIntOrFloat());
return vectorType.getNumElements() *
vectorType.getElementType().getIntOrFloatBitWidth();
}
llvm_unreachable("unhandled bit width computation for type");
}
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
@ -548,6 +566,61 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//
static ParseResult parseBitcastOp(OpAsmParser &parser, OperationState &state) {
OpAsmParser::OperandType operandInfo;
Type operandType, resultType;
if (parser.parseOperand(operandInfo) || parser.parseKeyword("from") ||
parser.parseType(operandType) || parser.parseKeyword("to") ||
parser.parseType(resultType)) {
return failure();
}
if (parser.resolveOperands(operandInfo, operandType, state.operands)) {
return failure();
}
state.addTypes(resultType);
return success();
}
static void print(spirv::BitcastOp bitcastOp, OpAsmPrinter &printer) {
printer << spirv::BitcastOp::getOperationName() << ' ';
printer.printOperand(bitcastOp.operand());
printer << " from " << bitcastOp.operand()->getType() << " to "
<< bitcastOp.result()->getType();
}
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
// TODO: The SPIR-V spec validation rules are different for different
// versions.
auto operandType = bitcastOp.operand()->getType();
auto resultType = bitcastOp.result()->getType();
if (operandType == resultType) {
return bitcastOp.emitError(
"result type must be different from operand type");
}
if (operandType.isa<spirv::PointerType>() &&
!resultType.isa<spirv::PointerType>()) {
return bitcastOp.emitError(
"unhandled bit cast conversion from pointer type to non-pointer type");
}
if (!operandType.isa<spirv::PointerType>() &&
resultType.isa<spirv::PointerType>()) {
return bitcastOp.emitError(
"unhandled bit cast conversion from non-pointer type to pointer type");
}
auto operandBitWidth = getBitWidth(operandType);
auto resultBitWidth = getBitWidth(resultType);
if (operandBitWidth != resultBitWidth) {
return bitcastOp.emitOpError("mismatch in result type bitwidth ")
<< resultBitWidth << " and operand type bitwidth "
<< operandBitWidth;
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,9 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module "Logical" "GLSL450" {
func @fmul(%arg0 : f32) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32
%0 = spv.Bitcast %arg0 from f32 to i32
spv.Return
}
}

View File

@ -120,6 +120,88 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
// -----
//===----------------------------------------------------------------------===//
// spv.Bitcast
//===----------------------------------------------------------------------===//
func @cast1(%arg0 : f32) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32
%0 = spv.Bitcast %arg0 from f32 to i32
return
}
func @cast2(%arg0 : vector<2xf32>) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to vector<2xi32>
%0 = spv.Bitcast %arg0 from vector<2xf32> to vector<2xi32>
return
}
func @cast3(%arg0 : vector<2xf32>) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to i64
%0 = spv.Bitcast %arg0 from vector<2xf32> to i64
return
}
func @cast4(%arg0 : !spv.ptr<f32, Function>) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
return
}
func @cast5(%arg0 : !spv.ptr<f32, Function>) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr<f32, Function> to !spv.ptr<vector<2xi32>, Function>
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to !spv.ptr<vector<2xi32>, Function>
return
}
func @cast6(%arg0 : vector<4xf32>) {
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<4xf32> to vector<2xi64>
%0 = spv.Bitcast %arg0 from vector<4xf32> to vector<2xi64>
return
}
// -----
func @cast1(%arg0 : f32) {
// expected-error @+1 {{result type must be different from operand type}}
%0 = spv.Bitcast %arg0 from f32 to f32
return
}
// -----
func @cast1(%arg0 : f32) {
// expected-error @+1 {{mismatch in result type bitwidth 64 and operand type bitwidth 32}}
%0 = spv.Bitcast %arg0 from f32 to i64
return
}
// -----
func @cast1(%arg0 : vector<2xf32>) {
// expected-error @+1 {{mismatch in result type bitwidth 96 and operand type bitwidth 64}}
%0 = spv.Bitcast %arg0 from vector<2xf32> to vector<3xf32>
return
}
// -----
func @cast3(%arg0 : !spv.ptr<f32, Function>) {
// expected-error @+1 {{unhandled bit cast conversion from pointer type to non-pointer type}}
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to i64
return
}
// -----
func @cast3(%arg0 : i64) {
// expected-error @+1 {{unhandled bit cast conversion from non-pointer type to pointer type}}
%0 = spv.Bitcast %arg0 from i64 to !spv.ptr<f32, Function>
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//