Make SPIR-V lowering infrastructure follow Vulkan SPIR-V validation.

The lowering infrastructure needs to be enhanced to lower into a
spv.Module that is consistent with the SPIR-V spec. The following
changes are needed
1) The Vulkan/SPIR-V validation rules dictates entry functions to have
signature of void(void). This requires changes to the function
signature conversion infrastructure within the dialect conversion
framework. When an argument is dropped from the original function
signature, a function can be specified that when invoked will return
the value to use as a replacement for the argument from the original
function.
2) Some changes to the type converter to make the converted type
consistent with the Vulkan/SPIR-V validation rules,
   a) Add support for converting dynamically shaped tensors to
   spv.rtarray type.
   b) Make the global variable of type !spv.ptr<!spv.struct<...>>
3) Generate the entry point operation for the kernel functions and
automatically compute all the interface variables needed

PiperOrigin-RevId: 273784229
This commit is contained in:
Mahesh Ravishankar 2019-10-09 11:25:25 -07:00 committed by A. Unique TensorFlower
parent 171637d4f0
commit e2ed25bc43
5 changed files with 258 additions and 128 deletions

View File

@ -49,16 +49,8 @@ public:
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter) explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
: basicTypeConverter(basicTypeConverter) {} : basicTypeConverter(basicTypeConverter) {}
/// Convert types to SPIR-V types using the basic type converter. /// Converts types to SPIR-V types using the basic type converter.
Type convertType(Type t) override { Type convertType(Type t) override;
return basicTypeConverter->convertType(t);
}
/// Method to convert argument of a function. The `type` is converted to
/// spv.ptr<type, Uniform>.
// TODO(ravishankarm) : Support other storage classes.
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) override;
/// Gets the basic type converter. /// Gets the basic type converter.
SPIRVBasicTypeConverter *getBasicTypeConverter() const { SPIRVBasicTypeConverter *getBasicTypeConverter() const {
@ -163,17 +155,20 @@ private:
}; };
/// Legalizes a function as a non-entry function. /// Legalizes a function as a non-entry function.
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands, LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp); FuncOp &newFuncOp);
/// Legalizes a function as an entry function. /// Legalizes a function as an entry function.
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, LogicalResult lowerAsEntryFunction(FuncOp funcOp,
SPIRVTypeConverter *typeConverter, SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp); FuncOp &newFuncOp);
/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
/// spv.ExecutionMode ops.
LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
/// Appends to a pattern list additional patterns for translating StandardOps to /// Appends to a pattern list additional patterns for translating StandardOps to
/// SPIR-V ops. /// SPIR-V ops.
void populateStandardToSPIRVPatterns(MLIRContext *context, void populateStandardToSPIRVPatterns(MLIRContext *context,

View File

@ -86,18 +86,15 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
auto funcOp = cast<FuncOp>(op); auto funcOp = cast<FuncOp>(op);
FuncOp newFuncOp; FuncOp newFuncOp;
if (!gpu::GPUDialect::isKernel(funcOp)) { if (!gpu::GPUDialect::isKernel(funcOp)) {
return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter, return succeeded(lowerFunction(funcOp, &typeConverter, rewriter, newFuncOp))
newFuncOp))
? matchSuccess() ? matchSuccess()
: matchFailure(); : matchFailure();
} }
if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter, if (failed(
newFuncOp))) { lowerAsEntryFunction(funcOp, &typeConverter, rewriter, newFuncOp))) {
return matchFailure(); return matchFailure();
} }
newFuncOp.getOperation()->removeAttr(Identifier::get(
gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
return matchSuccess(); return matchSuccess();
} }
@ -164,6 +161,24 @@ void GPUToSPIRVPass::runOnModule() {
&typeConverter))) { &typeConverter))) {
return signalPassFailure(); return signalPassFailure();
} }
// After the SPIR-V modules have been generated, some finalization is needed
// for the entry functions. For example, adding spv.EntryPoint op,
// spv.ExecutionMode op, etc.
for (auto *spvModule : spirvModules) {
for (auto op :
cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
if (gpu::GPUDialect::isKernel(op)) {
OpBuilder builder(op.getContext());
builder.setInsertionPointAfter(op);
if (failed(finalizeEntryFunction(op, builder))) {
return signalPassFailure();
}
op.getOperation()->removeAttr(Identifier::get(
gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
}
}
}
} }
OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }

View File

@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir; using namespace mlir;
@ -30,7 +31,7 @@ using namespace mlir;
// Type Conversion // Type Conversion
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Type SPIRVBasicTypeConverter::convertType(Type t) { static Type basicTypeConversion(Type t) {
// Check if the type is SPIR-V supported. If so return the type. // Check if the type is SPIR-V supported. If so return the type.
if (spirv::SPIRVDialect::isValidType(t)) { if (spirv::SPIRVDialect::isValidType(t)) {
return t; return t;
@ -42,75 +43,107 @@ Type SPIRVBasicTypeConverter::convertType(Type t) {
} }
if (auto memRefType = t.dyn_cast<MemRefType>()) { if (auto memRefType = t.dyn_cast<MemRefType>()) {
if (memRefType.hasStaticShape()) {
// Convert MemrefType to a multi-dimensional spv.array if size is known.
auto elementType = memRefType.getElementType(); auto elementType = memRefType.getElementType();
if (memRefType.hasStaticShape()) {
// Convert to a multi-dimensional spv.array if size is known.
for (auto size : reverse(memRefType.getShape())) { for (auto size : reverse(memRefType.getShape())) {
elementType = spirv::ArrayType::get(elementType, size); elementType = spirv::ArrayType::get(elementType, size);
} }
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage Classes.
return spirv::PointerType::get(elementType, return spirv::PointerType::get(elementType,
spirv::StorageClass::StorageBuffer); spirv::StorageClass::StorageBuffer);
} else {
// Vulkan SPIR-V validation rules require runtime array type to be the
// last member of a struct.
return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
spirv::StorageClass::StorageBuffer);
} }
} }
return Type(); return Type();
} }
Type SPIRVBasicTypeConverter::convertType(Type t) {
return basicTypeConversion(t);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Entry Function signature Conversion // Entry Function signature Conversion
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult /// Generates the type of variable given the type of object.
SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type, static Type getGlobalVarTypeForEntryFnArg(Type t) {
SignatureConversion &result) { auto convertedType = basicTypeConversion(t);
// Try to convert the given input type. if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
auto convertedType = basicTypeConverter->convertType(type); if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
// TODO(ravishankarm) : Vulkan spec requires these to be a return spirv::PointerType::get(
// spirv::StructType. This is not a SPIR-V requirement, so just making this a spirv::StructType::get(ptrType.getPointeeType()),
// pointer type for now. ptrType.getStorageClass());
if (!convertedType) }
return failure(); } else {
// For arguments to entry functions, convert the type into a pointer type if return spirv::PointerType::get(spirv::StructType::get(convertedType),
// it is already not one, unless the original type was an index type.
// TODO(ravishankarm): For arguments that are of index type, keep the
// arguments as the scalar converted type, i.e. i32. These are still not
// handled effectively. These are potentially best handled as specialization
// constants.
if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage classes.
convertedType = spirv::PointerType::get(convertedType,
spirv::StorageClass::StorageBuffer); spirv::StorageClass::StorageBuffer);
} }
return convertedType;
// Add the new inputs.
result.addInputs(inputNo, convertedType);
return success();
} }
static LogicalResult lowerFunctionImpl( Type SPIRVTypeConverter::convertType(Type t) {
FuncOp funcOp, ArrayRef<Value *> operands, return getGlobalVarTypeForEntryFnArg(t);
ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
return funcOp.emitError("SPIR-V dialect only supports functions with no "
"return values right now");
} }
for (auto &argType : enumerate(fnType.getInputs())) { /// Computes the replacement value for an argument of an entry function. It
// Get the type of the argument /// allocates a global variable for this argument and adds statements in the
if (failed(typeConverter->convertSignatureArg( /// entry block to get a replacement value within function scope.
argType.index(), argType.value(), signatureConverter))) { static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
return funcOp.emitError("unable to convert argument type ") size_t origArgNum,
<< argType.value() << " to SPIR-V type"; Value *origArg) {
// Create a global variable for this argument.
auto insertionOp = rewriter.getInsertionBlock()->getParent();
auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
return nullptr;
} }
auto funcOp = insertionOp->getParentOfType<FuncOp>();
spirv::GlobalVariableOp var;
{
OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&module.getBlock());
std::string varName =
funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
var = rewriter.create<spirv::GlobalVariableOp>(
funcOp.getLoc(),
rewriter.getTypeAttr(getGlobalVarTypeForEntryFnArg(origArg->getType())),
rewriter.getStringAttr(varName), nullptr);
var.setAttr(
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
rewriter.getI32IntegerAttr(0));
var.setAttr(
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
rewriter.getI32IntegerAttr(origArgNum));
}
// Insert the addressOf and load instructions, to get back the converted value
// type.
auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
rewriter.getIntegerType(32),
rewriter.getI32IntegerAttr(0));
auto accessChain = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), addressOf.pointer(), zero.constant());
// If the original argument is a tensor/memref type, the value is not
// loaded. Instead the pointer value is returned to allow its use in access
// chain ops.
auto origArgType = origArg->getType();
if (origArgType.isa<MemRefType>()) {
return accessChain;
}
return rewriter.create<spirv::LoadOp>(
funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
/*alignment=*/nullptr);
} }
static FuncOp applySignatureConversion(
FuncOp funcOp, ConversionPatternRewriter &rewriter,
TypeConverter::SignatureConversion &signatureConverter) {
// Create a new function with an updated signature. // Create a new function with an updated signature.
newFuncOp = rewriter.cloneWithoutRegions(funcOp); auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end()); newFuncOp.end());
newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(), newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
@ -119,72 +152,113 @@ static LogicalResult lowerFunctionImpl(
// Tell the rewriter to convert the region signature. // Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.replaceOp(funcOp.getOperation(), llvm::None); rewriter.replaceOp(funcOp.getOperation(), llvm::None);
return newFuncOp;
}
/// Gets the global variables that need to be specified as interface variable
/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
LogicalResult getInterfaceVariables(FuncOp funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return failure();
}
llvm::SetVector<Operation *> interfaceVarSet;
for (auto &block : funcOp) {
// TODO(ravishankarm) : This should in reality traverse the entry function
// call graph and collect all the interfaces. For now, just traverse the
// instructions in this function.
auto callOps = block.getOps<CallOp>();
if (std::distance(callOps.begin(), callOps.end())) {
return funcOp.emitError("Collecting interface variables through function "
"calls unimplemented");
}
for (auto op : block.getOps<spirv::AddressOfOp>()) {
auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
if (var.type().cast<spirv::PointerType>().getStorageClass() ==
spirv::StorageClass::StorageBuffer) {
continue;
}
interfaceVarSet.insert(var.getOperation());
}
}
for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get(
cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
}
return success(); return success();
} }
namespace mlir { namespace mlir {
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands, LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) { FuncOp &newFuncOp) {
auto fnType = funcOp.getType(); auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
return funcOp.emitError("SPIR-V lowering only supports functions with no "
"return values right now");
}
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
return lowerFunctionImpl(funcOp, operands, rewriter, auto basicTypeConverter = typeConverter->getBasicTypeConverter();
typeConverter->getBasicTypeConverter(), for (auto origArgType : enumerate(fnType.getInputs())) {
signatureConverter, newFuncOp); auto convertedType = basicTypeConverter->convertType(origArgType.value());
if (!convertedType) {
return funcOp.emitError("unable to convert argument of type '")
<< convertedType << "'";
}
signatureConverter.addInputs(origArgType.index(), convertedType);
}
newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
return success();
} }
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, LogicalResult lowerAsEntryFunction(FuncOp funcOp,
SPIRVTypeConverter *typeConverter, SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) { FuncOp &newFuncOp) {
auto fnType = funcOp.getType(); auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
return funcOp.emitError("SPIR-V lowering only supports functions with no "
"return values right now");
}
// For entry functions need to make the signature void(void). Compute the
// replacement value for all arguments and replace all uses.
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter, {
signatureConverter, newFuncOp))) { OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.front());
for (auto origArg : enumerate(funcOp.getArguments())) {
auto replacement = createAndLoadGlobalVarForEntryFnArg(
rewriter, origArg.index(), origArg.value());
rewriter.replaceUsesOfBlockArgument(origArg.value(), replacement);
}
}
newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
return success();
}
LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
// Add the spv.EntryPointOp after collecting all the interface variables
// needed.
SmallVector<Attribute, 1> interfaceVars;
if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
return failure(); return failure();
} }
// Create spv.globalVariable ops for each of the arguments. These need to be builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
// bound by the runtime. For now use descriptor_set 0, and arg number as the spirv::ExecutionModel::GLCompute,
// binding number. newFuncOp, interfaceVars);
auto module = funcOp.getParentOfType<spirv::ModuleOp>(); // Specify the spv.ExecutionModeOp.
if (!module) {
return funcOp.emitError("expected op to be within a spv.module"); /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
} /// LocalSize execution mode or an object decorated with the WorkgroupSize
auto ip = rewriter.saveInsertionPoint(); /// decoration must be specified." Better approach is to use the
rewriter.setInsertionPointToStart(&module.getBlock()); /// WorkgroupSize GlobalVariable with initializer being a specialization
SmallVector<Attribute, 4> interface; /// constant. But current support for specialization constant does not allow
for (auto &convertedArgType : /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
llvm::enumerate(signatureConverter.getConvertedTypes())) { /// 1} for now. To be fixed ASAP.
// TODO(ravishankarm) : The arguments to the converted function are either builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
// spirv::PointerType or i32 type, the latter due to conversion of index spirv::ExecutionMode::LocalSize,
// type to i32. Eventually entry function should be of signature ArrayRef<int32_t>{1, 1, 1});
// void(void). Arguments converted to spirv::PointerType, will be made
// variables and those converted to i32 will be made specialization
// constants. Latter is not implemented.
if (!convertedArgType.value().isa<spirv::PointerType>()) {
continue;
}
std::string varName = funcOp.getName().str() + "_arg_" +
std::to_string(convertedArgType.index());
auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
rewriter.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
variableOp.setAttr("binding",
rewriter.getI32IntegerAttr(convertedArgType.index()));
interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
rewriter.setInsertionPoint(&(module.getBlock().back()));
rewriter.create<spirv::EntryPointOp>(
funcOp.getLoc(),
rewriter.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
rewriter.getSymbolRefAttr(newFuncOp.getName()),
rewriter.getArrayAttr(interface));
rewriter.restoreInsertionPoint(ip);
return success(); return success();
} }
} // namespace mlir } // namespace mlir

View File

@ -16,13 +16,52 @@ module attributes {gpu.container_module} {
} }
// CHECK-LABEL: spv.module "Logical" "GLSL450" // CHECK-LABEL: spv.module "Logical" "GLSL450"
// CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
// CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
module @kernels attributes {gpu.kernel_module} { module @kernels attributes {gpu.kernel_module} {
// CHECK-DAG: spv.globalVariable [[WORKGROUPSIZEVAR:@.*]] built_in("WorkgroupSize") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
// CHECK: func [[FN:@.*]]()
func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
attributes {gpu.kernel} { attributes {gpu.kernel} {
// CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
// CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
// CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
// CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
// CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
// CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
// CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
// CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
// CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
// CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
// CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
// CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
// CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
// CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
// CHECK: [[ARG5:%.*]] = spv.Load "StorageBuffer" [[ARG5PTR]]
// CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
// CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
// CHECK: [[ARG6:%.*]] = spv.Load "StorageBuffer" [[ARG6PTR]]
// CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
// CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
// CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
// CHECK: [[ADDRESSLOCALINVOCATIONID:%.*]] = spv._address_of [[LOCALINVOCATIONIDVAR]]
// CHECK: [[LOCALINVOCATIONID:%.*]] = spv.Load "Input" [[ADDRESSLOCALINVOCATIONID]]
// CHECK: [[LOCALINVOCATIONIDX:%.*]] = spv.CompositeExtract [[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
%0 = "gpu.block_id"() {dimension = "x"} : () -> index %0 = "gpu.block_id"() {dimension = "x"} : () -> index
%1 = "gpu.block_id"() {dimension = "y"} : () -> index %1 = "gpu.block_id"() {dimension = "y"} : () -> index
%2 = "gpu.block_id"() {dimension = "z"} : () -> index %2 = "gpu.block_id"() {dimension = "z"} : () -> index
@ -35,9 +74,9 @@ module attributes {gpu.container_module} {
%9 = "gpu.block_dim"() {dimension = "x"} : () -> index %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
%10 = "gpu.block_dim"() {dimension = "y"} : () -> index %10 = "gpu.block_dim"() {dimension = "y"} : () -> index
%11 = "gpu.block_dim"() {dimension = "z"} : () -> index %11 = "gpu.block_dim"() {dimension = "z"} : () -> index
// CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}} // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], [[WORKGROUPIDX]]
%12 = addi %arg3, %0 : index %12 = addi %arg3, %0 : index
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}} // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
%13 = addi %arg4, %3 : index %13 = addi %arg4, %3 : index
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]] // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]

View File

@ -2,15 +2,23 @@
module attributes {gpu.container_module} { module attributes {gpu.container_module} {
// CHECK: spv.module "Logical" "GLSL450" {
// CHECK-NEXT: spv.globalVariable [[VAR1:@.*]] bind(0, 0) : !spv.ptr<f32, StorageBuffer>
// CHECK-NEXT: spv.globalVariable [[VAR2:@.*]] bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
// CHECK-NEXT: func @kernel_1
// CHECK-NEXT: spv.Return
// CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]
module @kernels attributes {gpu.kernel_module} { module @kernels attributes {gpu.kernel_module} {
// CHECK: spv.module "Logical" "GLSL450" {
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32>, StorageBuffer>
// CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
// CHECK: func [[FN:@.*]]()
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>) func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
attributes { gpu.kernel } { attributes { gpu.kernel } {
// CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
// CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
// CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
// CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
// CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
// CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
// CHECK-NEXT: spv.Return
// CHECK: spv.EntryPoint "GLCompute" [[FN]]
// CHECK: spv.ExecutionMode [[FN]] "LocalSize"
return return
} }
} }
@ -23,5 +31,4 @@ module attributes {gpu.container_module} {
: (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> () : (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
return return
} }
} }