From 6f0e65441c5cd018e8b0ad5c340435c0ee57eea1 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Wed, 25 Sep 2019 19:01:18 -0700 Subject: [PATCH] Add spv.Bitcast operation to SPIR-V dialect Support the OpBitcast instruction of SPIR-V using the spv.Bitcast operation. The semantics implemented in the dialect differ from the SPIR-V spec in that the dialect does not allow conversion to/from pointer types from/to non-pointer types. PiperOrigin-RevId: 271255957 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 12 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 57 ++++++++++++- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 73 +++++++++++++++++ .../Dialect/SPIRV/Serialization/cast-ops.mlir | 9 ++ mlir/test/Dialect/SPIRV/ops.mlir | 82 +++++++++++++++++++ 5 files changed, 227 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 081c6dfbd2f6..1440f75026b8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -116,6 +116,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; @@ -179,11 +180,11 @@ def SPV_OpcodeAttr : SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFNegate, - SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, - SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, - SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpBitcast, + SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, + SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, + SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, + SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, @@ -241,6 +242,7 @@ class SPV_ScalarOrVectorOf : AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>; def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; +def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; // TODO(antiagainst): Use a more appropriate way to model optional operands class SPV_Optional : Variadic; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 9614989e520b..eb4a876d7743 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -126,6 +126,58 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { // ----- +def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { + let summary = "Bit pattern-preserving type conversion."; + + let description = [{ + Result Type must be an OpTypePointer, or a scalar or vector of + numerical-type. + + Operand must have a type of OpTypePointer, or a scalar or vector of + numerical-type. It must be a different type than Result Type. + + If either Result Type or Operand is a pointer, the other must be a + pointer (diverges from the SPIR-V spec). + + If Result Type has a different number of components than Operand, the + total number of bits in Result Type must equal the total number of bits + in Operand. Let L be the type, either Result Type or Operand’s type, + that has the larger number of components. Let S be the other type, with + the smaller number of components. The number of components in L must be + an integer multiple of the number of components in S. The first + component (that is, the only or lowest-numbered component) of S maps to + the first components of L, and so on, up to the last component of S + mapping to the last components of L. Within this mapping, any single + component of S (mapping to multiple components of L) maps its lower- + ordered bits to the lower-numbered components of L. + + ### Custom assembly form + + ``` {.ebnf} + bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use + `from` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.Bitcast %0 from f32 to i32 + %1 = spv.Bitcast %0 from vector<2xf32> to i64 + %1 = spv.Bitcast %0 from !spv.ptr to !spv.ptr + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOrPtr:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOrPtr:$result + ); +} + +// ----- + def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let summary = "Extract a part of a composite object."; @@ -339,7 +391,10 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> { Ensures that memory accesses issued before this instruction will be observed before memory accesses issued after this instruction. This control is ensured only for memory accesses issued by this invocation - and observed by another invocation executing within Memory scope. + and observed by another invocation executing within Memory scope. If the + Vulkan memory model is declared, this ordering only applies to memory + accesses that use the NonPrivatePointer memory operand or + NonPrivateTexel image operand. Semantics declares what kind of memory is being controlled and what kind of control to apply. diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 1d637f8a30ee..408d365c250c 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -316,6 +316,24 @@ static Attribute extractCompositeElement(Attribute composite, return {}; } +// Get bit width of types. +static unsigned getBitWidth(Type type) { + if (type.isa()) { + // Just return 64 bits for pointer types for now. + // TODO: Make sure not caller relies on the actual pointer width value. + return 64; + } + if (type.isIntOrFloat()) { + return type.getIntOrFloatBitWidth(); + } + if (auto vectorType = type.dyn_cast()) { + assert(vectorType.getElementType().isIntOrFloat()); + return vectorType.getNumElements() * + vectorType.getElementType().getIntOrFloatBitWidth(); + } + llvm_unreachable("unhandled bit width computation for type"); +} + //===----------------------------------------------------------------------===// // Common parsers and printers //===----------------------------------------------------------------------===// @@ -548,6 +566,61 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.BitcastOp +//===----------------------------------------------------------------------===// + +static ParseResult parseBitcastOp(OpAsmParser &parser, OperationState &state) { + OpAsmParser::OperandType operandInfo; + Type operandType, resultType; + if (parser.parseOperand(operandInfo) || parser.parseKeyword("from") || + parser.parseType(operandType) || parser.parseKeyword("to") || + parser.parseType(resultType)) { + return failure(); + } + if (parser.resolveOperands(operandInfo, operandType, state.operands)) { + return failure(); + } + state.addTypes(resultType); + return success(); +} + +static void print(spirv::BitcastOp bitcastOp, OpAsmPrinter &printer) { + printer << spirv::BitcastOp::getOperationName() << ' '; + printer.printOperand(bitcastOp.operand()); + printer << " from " << bitcastOp.operand()->getType() << " to " + << bitcastOp.result()->getType(); +} + +static LogicalResult verify(spirv::BitcastOp bitcastOp) { + // TODO: The SPIR-V spec validation rules are different for different + // versions. + auto operandType = bitcastOp.operand()->getType(); + auto resultType = bitcastOp.result()->getType(); + if (operandType == resultType) { + return bitcastOp.emitError( + "result type must be different from operand type"); + } + if (operandType.isa() && + !resultType.isa()) { + return bitcastOp.emitError( + "unhandled bit cast conversion from pointer type to non-pointer type"); + } + if (!operandType.isa() && + resultType.isa()) { + return bitcastOp.emitError( + "unhandled bit cast conversion from non-pointer type to pointer type"); + } + auto operandBitWidth = getBitWidth(operandType); + auto resultBitWidth = getBitWidth(resultType); + if (operandBitWidth != resultBitWidth) { + return bitcastOp.emitOpError("mismatch in result type bitwidth ") + << resultBitWidth << " and operand type bitwidth " + << operandBitWidth; + } + return success(); +} + //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir new file mode 100644 index 000000000000..5e488f7dc775 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s + +spv.module "Logical" "GLSL450" { + func @fmul(%arg0 : f32) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32 + %0 = spv.Bitcast %arg0 from f32 to i32 + spv.Return + } +} diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index a78d088c2913..8c4b0fa2c58b 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -120,6 +120,88 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () { // ----- +//===----------------------------------------------------------------------===// +// spv.Bitcast +//===----------------------------------------------------------------------===// + +func @cast1(%arg0 : f32) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32 + %0 = spv.Bitcast %arg0 from f32 to i32 + return +} + +func @cast2(%arg0 : vector<2xf32>) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to vector<2xi32> + %0 = spv.Bitcast %arg0 from vector<2xf32> to vector<2xi32> + return +} + +func @cast3(%arg0 : vector<2xf32>) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to i64 + %0 = spv.Bitcast %arg0 from vector<2xf32> to i64 + return +} + +func @cast4(%arg0 : !spv.ptr) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr to !spv.ptr + %0 = spv.Bitcast %arg0 from !spv.ptr to !spv.ptr + return +} + +func @cast5(%arg0 : !spv.ptr) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr to !spv.ptr, Function> + %0 = spv.Bitcast %arg0 from !spv.ptr to !spv.ptr, Function> + return +} + +func @cast6(%arg0 : vector<4xf32>) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<4xf32> to vector<2xi64> + %0 = spv.Bitcast %arg0 from vector<4xf32> to vector<2xi64> + return +} + +// ----- + +func @cast1(%arg0 : f32) { + // expected-error @+1 {{result type must be different from operand type}} + %0 = spv.Bitcast %arg0 from f32 to f32 + return +} + +// ----- + +func @cast1(%arg0 : f32) { + // expected-error @+1 {{mismatch in result type bitwidth 64 and operand type bitwidth 32}} + %0 = spv.Bitcast %arg0 from f32 to i64 + return +} + +// ----- + +func @cast1(%arg0 : vector<2xf32>) { + // expected-error @+1 {{mismatch in result type bitwidth 96 and operand type bitwidth 64}} + %0 = spv.Bitcast %arg0 from vector<2xf32> to vector<3xf32> + return +} + +// ----- + +func @cast3(%arg0 : !spv.ptr) { + // expected-error @+1 {{unhandled bit cast conversion from pointer type to non-pointer type}} + %0 = spv.Bitcast %arg0 from !spv.ptr to i64 + return +} + +// ----- + +func @cast3(%arg0 : i64) { + // expected-error @+1 {{unhandled bit cast conversion from non-pointer type to pointer type}} + %0 = spv.Bitcast %arg0 from i64 to !spv.ptr + return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===//