forked from OSchip/llvm-project
[mlir][spirv] Add (InBounds)PtrAccessChain ops
Differential Revision: https://reviews.llvm.org/D108070
This commit is contained in:
parent
ffe58de393
commit
ddc3d51d58
|
@ -3194,6 +3194,8 @@ def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
|||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
|
||||
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
|
||||
def SPV_OC_OpPtrAccessChain : I32EnumAttrCase<"OpPtrAccessChain", 67>;
|
||||
def SPV_OC_OpInBoundsPtrAccessChain : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>;
|
||||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
|
||||
def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
|
||||
|
@ -3340,10 +3342,10 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
|
||||
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
|
||||
SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
|
||||
SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
|
||||
SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
|
||||
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpPtrAccessChain,
|
||||
SPV_OC_OpInBoundsPtrAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
|
||||
SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
|
||||
SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather,
|
||||
SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU,
|
||||
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
|
||||
|
|
|
@ -137,6 +137,55 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_InBoundsPtrAccessChainOp : SPV_Op<"InBoundsPtrAccessChain", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Has the same semantics as OpPtrAccessChain, with the addition that the
|
||||
resulting pointer is known to point within the base object.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
access-chain-op ::= ssa-id `=` `spv.InBoundsPtrAccessChain` ssa-use
|
||||
`[` ssa-use (',' ssa-use)* `]`
|
||||
`:` pointer-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
func @inbounds_ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
|
||||
%0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
...
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Addresses]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyPtr:$base_ptr,
|
||||
SPV_Integer:$element,
|
||||
Variadic<SPV_Integer>:$indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyPtr:$result
|
||||
);
|
||||
|
||||
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_LoadOp : SPV_Op<"Load", []> {
|
||||
let summary = "Load through a pointer.";
|
||||
|
||||
|
@ -191,6 +240,78 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_PtrAccessChainOp : SPV_Op<"PtrAccessChain", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Has the same semantics as OpAccessChain, with the addition of the
|
||||
Element operand.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Element is used to do an initial dereference of Base: Base is treated as
|
||||
the address of an element in an array, and a new element address is
|
||||
computed from Base and Element to become the OpAccessChain Base to
|
||||
dereference as per OpAccessChain. This computed Base has the same type
|
||||
as the originating Base.
|
||||
|
||||
To compute the new element address, Element is treated as a signed count
|
||||
of elements E, relative to the original Base element B, and the address
|
||||
of element B + E is computed using enough precision to avoid overflow
|
||||
and underflow. For objects in the Uniform, StorageBuffer, or
|
||||
PushConstant storage classes, the element's address or location is
|
||||
calculated using a stride, which will be the Base-type's Array Stride if
|
||||
the Base type is decorated with ArrayStride. For all other objects, the
|
||||
implementation calculates the element's address or location.
|
||||
|
||||
With one exception, undefined behavior results when B + E is not an
|
||||
element in the same array (same innermost array, if array types are
|
||||
nested) as B. The exception being when B + E = L, where L is the length
|
||||
of the array: the address computation for element L is done with the
|
||||
same stride as any other B + E computation that stays within the array.
|
||||
|
||||
Note: If Base is typed to be a pointer to an array and the desired
|
||||
operation is to select an element of that array, OpAccessChain should be
|
||||
directly used, as its first Index selects the array element.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
```
|
||||
[access-chain-op ::= ssa-id `=` `spv.PtrAccessChain` ssa-use
|
||||
`[` ssa-use (',' ssa-use)* `]`
|
||||
`:` pointer-type
|
||||
```mlir
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
func @ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
|
||||
%0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
...
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPV_V_1_0>,
|
||||
MaxVersion<SPV_V_1_5>,
|
||||
Extension<[]>,
|
||||
Capability<[SPV_C_Addresses, SPV_C_PhysicalStorageBufferAddresses, SPV_C_VariablePointers, SPV_C_VariablePointersStorageBuffer]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_AnyPtr:$base_ptr,
|
||||
SPV_Integer:$element,
|
||||
Variadic<SPV_Integer>:$indices
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyPtr:$result
|
||||
);
|
||||
|
||||
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_StoreOp : SPV_Op<"Store", []> {
|
||||
let summary = "Store through a pointer.";
|
||||
|
||||
|
|
|
@ -1019,37 +1019,41 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
|
||||
<< '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
|
||||
<< op.indices().getTypes();
|
||||
template <typename Op>
|
||||
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
|
||||
printer << Op::getOperationName() << ' ' << op.base_ptr() << '[' << indices
|
||||
<< "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
||||
SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
|
||||
accessChainOp.indices().end());
|
||||
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
|
||||
printAccessChain(op, op.indices(), printer);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
|
||||
auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
|
||||
indices, accessChainOp.getLoc());
|
||||
if (!resultType) {
|
||||
if (!resultType)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto providedResultType =
|
||||
accessChainOp.getType().dyn_cast<spirv::PointerType>();
|
||||
if (!providedResultType) {
|
||||
accessChainOp.getType().template dyn_cast<spirv::PointerType>();
|
||||
if (!providedResultType)
|
||||
return accessChainOp.emitOpError(
|
||||
"result type must be a pointer, but provided")
|
||||
<< providedResultType;
|
||||
}
|
||||
|
||||
if (resultType != providedResultType) {
|
||||
if (resultType != providedResultType)
|
||||
return accessChainOp.emitOpError("invalid result type: expected ")
|
||||
<< resultType << ", but provided " << providedResultType;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
||||
return verifyAccessChain(accessChainOp, accessChainOp.indices());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.mlir.addressof
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3770,6 +3774,109 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
|
||||
OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
OpAsmParser::OperandType ptrInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
|
||||
Type type;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
SmallVector<Type, 4> indicesTypes;
|
||||
|
||||
if (parser.parseOperand(ptrInfo) ||
|
||||
parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseColonType(type) ||
|
||||
parser.resolveOperand(ptrInfo, type, state.operands))
|
||||
return failure();
|
||||
|
||||
// Check that the provided indices list is not empty before parsing their
|
||||
// type list.
|
||||
if (indicesInfo.empty())
|
||||
return emitError(state.location) << opName << " expected element";
|
||||
|
||||
if (parser.parseComma() || parser.parseTypeList(indicesTypes))
|
||||
return failure();
|
||||
|
||||
// Check that the indices types list is not empty and that it has a one-to-one
|
||||
// mapping to the provided indices.
|
||||
if (indicesTypes.size() != indicesInfo.size())
|
||||
return emitError(state.location)
|
||||
<< opName
|
||||
<< " indices types' count must be equal to indices info count";
|
||||
|
||||
if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
|
||||
return failure();
|
||||
|
||||
auto resultType = getElementPtrType(
|
||||
type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
|
||||
if (!resultType)
|
||||
return failure();
|
||||
|
||||
state.addTypes(resultType);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
static auto concatElemAndIndices(Op op) {
|
||||
SmallVector<Value> ret(op.indices().size() + 1);
|
||||
ret[0] = op.element();
|
||||
llvm::copy(op.indices(), ret.begin() + 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.InBoundsPtrAccessChainOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
|
||||
OperationState &state,
|
||||
Value basePtr, Value element,
|
||||
ValueRange indices) {
|
||||
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
|
||||
assert(type && "Unable to deduce return type based on basePtr and indices");
|
||||
build(builder, state, type, basePtr, element, indices);
|
||||
}
|
||||
|
||||
static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
return parsePtrAccessChainOpImpl(
|
||||
spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
|
||||
}
|
||||
|
||||
static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
|
||||
printAccessChain(op, concatElemAndIndices(op), printer);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
|
||||
return verifyAccessChain(accessChainOp, accessChainOp.indices());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.PtrAccessChainOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
|
||||
Value basePtr, Value element,
|
||||
ValueRange indices) {
|
||||
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
|
||||
assert(type && "Unable to deduce return type based on basePtr and indices");
|
||||
build(builder, state, type, basePtr, element, indices);
|
||||
}
|
||||
|
||||
static ParseResult parsePtrAccessChainOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
|
||||
parser, state);
|
||||
}
|
||||
|
||||
static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
|
||||
printAccessChain(op, concatElemAndIndices(op), printer);
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
|
||||
return verifyAccessChain(accessChainOp, accessChainOp.indices());
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
|
|
|
@ -628,3 +628,33 @@ func @copy_memory_print_maa() {
|
|||
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.PtrAccessChain
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @ptr_access_chain1(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i64)
|
||||
// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
func @ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
|
||||
%0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.InBoundsPtrAccessChain
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @inbounds_ptr_access_chain1(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i64)
|
||||
// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
func @inbounds_ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
|
||||
%0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue