forked from OSchip/llvm-project
Add spv.Bitcast operation to SPIR-V dialect
Support the OpBitcast instruction of SPIR-V using the spv.Bitcast operation. The semantics implemented in the dialect differ from the SPIR-V spec in that the dialect does not allow conversion to/from pointer types from/to non-pointer types. PiperOrigin-RevId: 271255957
This commit is contained in:
parent
47a7021cc3
commit
6f0e65441c
|
@ -116,6 +116,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
|||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>;
|
||||
def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>;
|
||||
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
|
||||
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
|
||||
|
@ -179,11 +180,11 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFNegate,
|
||||
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
|
||||
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
|
||||
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpBitcast,
|
||||
SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
|
||||
SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
|
||||
SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
|
||||
SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||
|
@ -241,6 +242,7 @@ class SPV_ScalarOrVectorOf<Type type> :
|
|||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
|
||||
|
||||
def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
|
||||
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
|
||||
|
||||
// TODO(antiagainst): Use a more appropriate way to model optional operands
|
||||
class SPV_Optional<Type type> : Variadic<type>;
|
||||
|
|
|
@ -126,6 +126,58 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> {
|
||||
let summary = "Bit pattern-preserving type conversion.";
|
||||
|
||||
let description = [{
|
||||
Result Type must be an OpTypePointer, or a scalar or vector of
|
||||
numerical-type.
|
||||
|
||||
Operand must have a type of OpTypePointer, or a scalar or vector of
|
||||
numerical-type. It must be a different type than Result Type.
|
||||
|
||||
If either Result Type or Operand is a pointer, the other must be a
|
||||
pointer (diverges from the SPIR-V spec).
|
||||
|
||||
If Result Type has a different number of components than Operand, the
|
||||
total number of bits in Result Type must equal the total number of bits
|
||||
in Operand. Let L be the type, either Result Type or Operand’s type,
|
||||
that has the larger number of components. Let S be the other type, with
|
||||
the smaller number of components. The number of components in L must be
|
||||
an integer multiple of the number of components in S. The first
|
||||
component (that is, the only or lowest-numbered component) of S maps to
|
||||
the first components of L, and so on, up to the last component of S
|
||||
mapping to the last components of L. Within this mapping, any single
|
||||
component of S (mapping to multiple components of L) maps its lower-
|
||||
ordered bits to the lower-numbered components of L.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use
|
||||
`from` operand-type `to` result-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%1 = spv.Bitcast %0 from f32 to i32
|
||||
%1 = spv.Bitcast %0 from vector<2xf32> to i64
|
||||
%1 = spv.Bitcast %0 from !spv.ptr<f32 to Function> to !spv.ptr<i32, Function>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_ScalarOrVectorOrPtr:$operand
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_ScalarOrVectorOrPtr:$result
|
||||
);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
|
||||
let summary = "Extract a part of a composite object.";
|
||||
|
||||
|
@ -339,7 +391,10 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
|
|||
Ensures that memory accesses issued before this instruction will be
|
||||
observed before memory accesses issued after this instruction. This
|
||||
control is ensured only for memory accesses issued by this invocation
|
||||
and observed by another invocation executing within Memory scope.
|
||||
and observed by another invocation executing within Memory scope. If the
|
||||
Vulkan memory model is declared, this ordering only applies to memory
|
||||
accesses that use the NonPrivatePointer memory operand or
|
||||
NonPrivateTexel image operand.
|
||||
|
||||
Semantics declares what kind of memory is being controlled and what kind
|
||||
of control to apply.
|
||||
|
|
|
@ -316,6 +316,24 @@ static Attribute extractCompositeElement(Attribute composite,
|
|||
return {};
|
||||
}
|
||||
|
||||
// Get bit width of types.
|
||||
static unsigned getBitWidth(Type type) {
|
||||
if (type.isa<spirv::PointerType>()) {
|
||||
// Just return 64 bits for pointer types for now.
|
||||
// TODO: Make sure not caller relies on the actual pointer width value.
|
||||
return 64;
|
||||
}
|
||||
if (type.isIntOrFloat()) {
|
||||
return type.getIntOrFloatBitWidth();
|
||||
}
|
||||
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
||||
assert(vectorType.getElementType().isIntOrFloat());
|
||||
return vectorType.getNumElements() *
|
||||
vectorType.getElementType().getIntOrFloatBitWidth();
|
||||
}
|
||||
llvm_unreachable("unhandled bit width computation for type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common parsers and printers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -548,6 +566,61 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BitcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBitcastOp(OpAsmParser &parser, OperationState &state) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
Type operandType, resultType;
|
||||
if (parser.parseOperand(operandInfo) || parser.parseKeyword("from") ||
|
||||
parser.parseType(operandType) || parser.parseKeyword("to") ||
|
||||
parser.parseType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
if (parser.resolveOperands(operandInfo, operandType, state.operands)) {
|
||||
return failure();
|
||||
}
|
||||
state.addTypes(resultType);
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::BitcastOp bitcastOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::BitcastOp::getOperationName() << ' ';
|
||||
printer.printOperand(bitcastOp.operand());
|
||||
printer << " from " << bitcastOp.operand()->getType() << " to "
|
||||
<< bitcastOp.result()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
|
||||
// TODO: The SPIR-V spec validation rules are different for different
|
||||
// versions.
|
||||
auto operandType = bitcastOp.operand()->getType();
|
||||
auto resultType = bitcastOp.result()->getType();
|
||||
if (operandType == resultType) {
|
||||
return bitcastOp.emitError(
|
||||
"result type must be different from operand type");
|
||||
}
|
||||
if (operandType.isa<spirv::PointerType>() &&
|
||||
!resultType.isa<spirv::PointerType>()) {
|
||||
return bitcastOp.emitError(
|
||||
"unhandled bit cast conversion from pointer type to non-pointer type");
|
||||
}
|
||||
if (!operandType.isa<spirv::PointerType>() &&
|
||||
resultType.isa<spirv::PointerType>()) {
|
||||
return bitcastOp.emitError(
|
||||
"unhandled bit cast conversion from non-pointer type to pointer type");
|
||||
}
|
||||
auto operandBitWidth = getBitWidth(operandType);
|
||||
auto resultBitWidth = getBitWidth(resultType);
|
||||
if (operandBitWidth != resultBitWidth) {
|
||||
return bitcastOp.emitOpError("mismatch in result type bitwidth ")
|
||||
<< resultBitWidth << " and operand type bitwidth "
|
||||
<< operandBitWidth;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @fmul(%arg0 : f32) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32
|
||||
%0 = spv.Bitcast %arg0 from f32 to i32
|
||||
spv.Return
|
||||
}
|
||||
}
|
|
@ -120,6 +120,88 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Bitcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @cast1(%arg0 : f32) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32
|
||||
%0 = spv.Bitcast %arg0 from f32 to i32
|
||||
return
|
||||
}
|
||||
|
||||
func @cast2(%arg0 : vector<2xf32>) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to vector<2xi32>
|
||||
%0 = spv.Bitcast %arg0 from vector<2xf32> to vector<2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
func @cast3(%arg0 : vector<2xf32>) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to i64
|
||||
%0 = spv.Bitcast %arg0 from vector<2xf32> to i64
|
||||
return
|
||||
}
|
||||
|
||||
func @cast4(%arg0 : !spv.ptr<f32, Function>) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
|
||||
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to !spv.ptr<i32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
func @cast5(%arg0 : !spv.ptr<f32, Function>) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr<f32, Function> to !spv.ptr<vector<2xi32>, Function>
|
||||
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to !spv.ptr<vector<2xi32>, Function>
|
||||
return
|
||||
}
|
||||
|
||||
func @cast6(%arg0 : vector<4xf32>) {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<4xf32> to vector<2xi64>
|
||||
%0 = spv.Bitcast %arg0 from vector<4xf32> to vector<2xi64>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast1(%arg0 : f32) {
|
||||
// expected-error @+1 {{result type must be different from operand type}}
|
||||
%0 = spv.Bitcast %arg0 from f32 to f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast1(%arg0 : f32) {
|
||||
// expected-error @+1 {{mismatch in result type bitwidth 64 and operand type bitwidth 32}}
|
||||
%0 = spv.Bitcast %arg0 from f32 to i64
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast1(%arg0 : vector<2xf32>) {
|
||||
// expected-error @+1 {{mismatch in result type bitwidth 96 and operand type bitwidth 64}}
|
||||
%0 = spv.Bitcast %arg0 from vector<2xf32> to vector<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast3(%arg0 : !spv.ptr<f32, Function>) {
|
||||
// expected-error @+1 {{unhandled bit cast conversion from pointer type to non-pointer type}}
|
||||
%0 = spv.Bitcast %arg0 from !spv.ptr<f32, Function> to i64
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cast3(%arg0 : i64) {
|
||||
// expected-error @+1 {{unhandled bit cast conversion from non-pointer type to pointer type}}
|
||||
%0 = spv.Bitcast %arg0 from i64 to !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CompositeExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue