[mlir][spirv] Add (InBounds)PtrAccessChain ops

Differential Revision: https://reviews.llvm.org/D108070
This commit is contained in:
Butygin 2021-08-14 11:57:02 +03:00
parent ffe58de393
commit ddc3d51d58
4 changed files with 278 additions and 18 deletions

View File

@ -3194,6 +3194,8 @@ def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpPtrAccessChain : I32EnumAttrCase<"OpPtrAccessChain", 67>;
def SPV_OC_OpInBoundsPtrAccessChain : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
@ -3340,10 +3342,10 @@ def SPV_OpcodeAttr :
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpPtrAccessChain,
SPV_OC_OpInBoundsPtrAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather,
SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,

View File

@ -137,6 +137,55 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
// -----
def SPV_InBoundsPtrAccessChainOp : SPV_Op<"InBoundsPtrAccessChain", [NoSideEffect]> {
let summary = [{
Has the same semantics as OpPtrAccessChain, with the addition that the
resulting pointer is known to point within the base object.
}];
let description = [{
<!-- End of AutoGen section -->
```
access-chain-op ::= ssa-id `=` `spv.InBoundsPtrAccessChain` ssa-use
`[` ssa-use (',' ssa-use)* `]`
`:` pointer-type
```mlir
#### Example:
```
func @inbounds_ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
%0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
...
}
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Addresses]>
];
let arguments = (ins
SPV_AnyPtr:$base_ptr,
SPV_Integer:$element,
Variadic<SPV_Integer>:$indices
);
let results = (outs
SPV_AnyPtr:$result
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
}
// -----
def SPV_LoadOp : SPV_Op<"Load", []> {
let summary = "Load through a pointer.";
@ -191,6 +240,78 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
// -----
def SPV_PtrAccessChainOp : SPV_Op<"PtrAccessChain", [NoSideEffect]> {
let summary = [{
Has the same semantics as OpAccessChain, with the addition of the
Element operand.
}];
let description = [{
Element is used to do an initial dereference of Base: Base is treated as
the address of an element in an array, and a new element address is
computed from Base and Element to become the OpAccessChain Base to
dereference as per OpAccessChain. This computed Base has the same type
as the originating Base.
To compute the new element address, Element is treated as a signed count
of elements E, relative to the original Base element B, and the address
of element B + E is computed using enough precision to avoid overflow
and underflow. For objects in the Uniform, StorageBuffer, or
PushConstant storage classes, the element's address or location is
calculated using a stride, which will be the Base-type's Array Stride if
the Base type is decorated with ArrayStride. For all other objects, the
implementation calculates the element's address or location.
With one exception, undefined behavior results when B + E is not an
element in the same array (same innermost array, if array types are
nested) as B. The exception being when B + E = L, where L is the length
of the array: the address computation for element L is done with the
same stride as any other B + E computation that stays within the array.
Note: If Base is typed to be a pointer to an array and the desired
operation is to select an element of that array, OpAccessChain should be
directly used, as its first Index selects the array element.
<!-- End of AutoGen section -->
```
[access-chain-op ::= ssa-id `=` `spv.PtrAccessChain` ssa-use
`[` ssa-use (',' ssa-use)* `]`
`:` pointer-type
```mlir
#### Example:
```
func @ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
%0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
...
}
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[]>,
Capability<[SPV_C_Addresses, SPV_C_PhysicalStorageBufferAddresses, SPV_C_VariablePointers, SPV_C_VariablePointersStorageBuffer]>
];
let arguments = (ins
SPV_AnyPtr:$base_ptr,
SPV_Integer:$element,
Variadic<SPV_Integer>:$indices
);
let results = (outs
SPV_AnyPtr:$result
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
}
// -----
def SPV_StoreOp : SPV_Op<"Store", []> {
let summary = "Store through a pointer.";

View File

@ -1019,37 +1019,41 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
return success();
}
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
<< '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
<< op.indices().getTypes();
template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
printer << Op::getOperationName() << ' ' << op.base_ptr() << '[' << indices
<< "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
accessChainOp.indices().end());
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
printAccessChain(op, op.indices(), printer);
}
template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
indices, accessChainOp.getLoc());
if (!resultType) {
if (!resultType)
return failure();
}
auto providedResultType =
accessChainOp.getType().dyn_cast<spirv::PointerType>();
if (!providedResultType) {
accessChainOp.getType().template dyn_cast<spirv::PointerType>();
if (!providedResultType)
return accessChainOp.emitOpError(
"result type must be a pointer, but provided")
<< providedResultType;
}
if (resultType != providedResultType) {
if (resultType != providedResultType)
return accessChainOp.emitOpError("invalid result type: expected ")
<< resultType << ", but provided " << providedResultType;
}
return success();
}
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
return verifyAccessChain(accessChainOp, accessChainOp.indices());
}
//===----------------------------------------------------------------------===//
// spv.mlir.addressof
//===----------------------------------------------------------------------===//
@ -3770,6 +3774,109 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
return success();
}
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType ptrInfo;
SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
Type type;
auto loc = parser.getCurrentLocation();
SmallVector<Type, 4> indicesTypes;
if (parser.parseOperand(ptrInfo) ||
parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(type) ||
parser.resolveOperand(ptrInfo, type, state.operands))
return failure();
// Check that the provided indices list is not empty before parsing their
// type list.
if (indicesInfo.empty())
return emitError(state.location) << opName << " expected element";
if (parser.parseComma() || parser.parseTypeList(indicesTypes))
return failure();
// Check that the indices types list is not empty and that it has a one-to-one
// mapping to the provided indices.
if (indicesTypes.size() != indicesInfo.size())
return emitError(state.location)
<< opName
<< " indices types' count must be equal to indices info count";
if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
return failure();
auto resultType = getElementPtrType(
type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
if (!resultType)
return failure();
state.addTypes(resultType);
return success();
}
template <typename Op>
static auto concatElemAndIndices(Op op) {
SmallVector<Value> ret(op.indices().size() + 1);
ret[0] = op.element();
llvm::copy(op.indices(), ret.begin() + 1);
return ret;
}
//===----------------------------------------------------------------------===//
// spv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//
void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
OperationState &state,
Value basePtr, Value element,
ValueRange indices) {
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
assert(type && "Unable to deduce return type based on basePtr and indices");
build(builder, state, type, basePtr, element, indices);
}
static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser,
OperationState &state) {
return parsePtrAccessChainOpImpl(
spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
}
static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
printAccessChain(op, concatElemAndIndices(op), printer);
}
static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
return verifyAccessChain(accessChainOp, accessChainOp.indices());
}
//===----------------------------------------------------------------------===//
// spv.PtrAccessChainOp
//===----------------------------------------------------------------------===//
void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
Value basePtr, Value element,
ValueRange indices) {
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
assert(type && "Unable to deduce return type based on basePtr and indices");
build(builder, state, type, basePtr, element, indices);
}
static ParseResult parsePtrAccessChainOp(OpAsmParser &parser,
OperationState &state) {
return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
parser, state);
}
static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
printAccessChain(op, concatElemAndIndices(op), printer);
}
static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
return verifyAccessChain(accessChainOp, accessChainOp.indices());
}
namespace mlir {
namespace spirv {

View File

@ -628,3 +628,33 @@ func @copy_memory_print_maa() {
spv.Return
}
// -----
//===----------------------------------------------------------------------===//
// spv.PtrAccessChain
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @ptr_access_chain1(
// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
func @ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
%0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.InBoundsPtrAccessChain
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @inbounds_ptr_access_chain1(
// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
func @inbounds_ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
%0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
return
}