forked from OSchip/llvm-project
Make SPIR-V lowering infrastructure follow Vulkan SPIR-V validation.
The lowering infrastructure needs to be enhanced to lower into a spv.Module that is consistent with the SPIR-V spec. The following changes are needed 1) The Vulkan/SPIR-V validation rules dictates entry functions to have signature of void(void). This requires changes to the function signature conversion infrastructure within the dialect conversion framework. When an argument is dropped from the original function signature, a function can be specified that when invoked will return the value to use as a replacement for the argument from the original function. 2) Some changes to the type converter to make the converted type consistent with the Vulkan/SPIR-V validation rules, a) Add support for converting dynamically shaped tensors to spv.rtarray type. b) Make the global variable of type !spv.ptr<!spv.struct<...>> 3) Generate the entry point operation for the kernel functions and automatically compute all the interface variables needed PiperOrigin-RevId: 273784229
This commit is contained in:
parent
171637d4f0
commit
e2ed25bc43
|
@ -49,16 +49,8 @@ public:
|
||||||
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
|
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
|
||||||
: basicTypeConverter(basicTypeConverter) {}
|
: basicTypeConverter(basicTypeConverter) {}
|
||||||
|
|
||||||
/// Convert types to SPIR-V types using the basic type converter.
|
/// Converts types to SPIR-V types using the basic type converter.
|
||||||
Type convertType(Type t) override {
|
Type convertType(Type t) override;
|
||||||
return basicTypeConverter->convertType(t);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Method to convert argument of a function. The `type` is converted to
|
|
||||||
/// spv.ptr<type, Uniform>.
|
|
||||||
// TODO(ravishankarm) : Support other storage classes.
|
|
||||||
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
|
|
||||||
SignatureConversion &result) override;
|
|
||||||
|
|
||||||
/// Gets the basic type converter.
|
/// Gets the basic type converter.
|
||||||
SPIRVBasicTypeConverter *getBasicTypeConverter() const {
|
SPIRVBasicTypeConverter *getBasicTypeConverter() const {
|
||||||
|
@ -163,17 +155,20 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Legalizes a function as a non-entry function.
|
/// Legalizes a function as a non-entry function.
|
||||||
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
|
||||||
SPIRVTypeConverter *typeConverter,
|
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
FuncOp &newFuncOp);
|
FuncOp &newFuncOp);
|
||||||
|
|
||||||
/// Legalizes a function as an entry function.
|
/// Legalizes a function as an entry function.
|
||||||
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
LogicalResult lowerAsEntryFunction(FuncOp funcOp,
|
||||||
SPIRVTypeConverter *typeConverter,
|
SPIRVTypeConverter *typeConverter,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
FuncOp &newFuncOp);
|
FuncOp &newFuncOp);
|
||||||
|
|
||||||
|
/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
|
||||||
|
/// spv.ExecutionMode ops.
|
||||||
|
LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
|
||||||
|
|
||||||
/// Appends to a pattern list additional patterns for translating StandardOps to
|
/// Appends to a pattern list additional patterns for translating StandardOps to
|
||||||
/// SPIR-V ops.
|
/// SPIR-V ops.
|
||||||
void populateStandardToSPIRVPatterns(MLIRContext *context,
|
void populateStandardToSPIRVPatterns(MLIRContext *context,
|
||||||
|
|
|
@ -86,18 +86,15 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
auto funcOp = cast<FuncOp>(op);
|
auto funcOp = cast<FuncOp>(op);
|
||||||
FuncOp newFuncOp;
|
FuncOp newFuncOp;
|
||||||
if (!gpu::GPUDialect::isKernel(funcOp)) {
|
if (!gpu::GPUDialect::isKernel(funcOp)) {
|
||||||
return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
|
return succeeded(lowerFunction(funcOp, &typeConverter, rewriter, newFuncOp))
|
||||||
newFuncOp))
|
|
||||||
? matchSuccess()
|
? matchSuccess()
|
||||||
: matchFailure();
|
: matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
|
if (failed(
|
||||||
newFuncOp))) {
|
lowerAsEntryFunction(funcOp, &typeConverter, rewriter, newFuncOp))) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
newFuncOp.getOperation()->removeAttr(Identifier::get(
|
|
||||||
gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
|
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,6 +161,24 @@ void GPUToSPIRVPass::runOnModule() {
|
||||||
&typeConverter))) {
|
&typeConverter))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After the SPIR-V modules have been generated, some finalization is needed
|
||||||
|
// for the entry functions. For example, adding spv.EntryPoint op,
|
||||||
|
// spv.ExecutionMode op, etc.
|
||||||
|
for (auto *spvModule : spirvModules) {
|
||||||
|
for (auto op :
|
||||||
|
cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
|
||||||
|
if (gpu::GPUDialect::isKernel(op)) {
|
||||||
|
OpBuilder builder(op.getContext());
|
||||||
|
builder.setInsertionPointAfter(op);
|
||||||
|
if (failed(finalizeEntryFunction(op, builder))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
op.getOperation()->removeAttr(Identifier::get(
|
||||||
|
gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
|
OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -30,7 +31,7 @@ using namespace mlir;
|
||||||
// Type Conversion
|
// Type Conversion
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Type SPIRVBasicTypeConverter::convertType(Type t) {
|
static Type basicTypeConversion(Type t) {
|
||||||
// Check if the type is SPIR-V supported. If so return the type.
|
// Check if the type is SPIR-V supported. If so return the type.
|
||||||
if (spirv::SPIRVDialect::isValidType(t)) {
|
if (spirv::SPIRVDialect::isValidType(t)) {
|
||||||
return t;
|
return t;
|
||||||
|
@ -42,75 +43,107 @@ Type SPIRVBasicTypeConverter::convertType(Type t) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto memRefType = t.dyn_cast<MemRefType>()) {
|
if (auto memRefType = t.dyn_cast<MemRefType>()) {
|
||||||
if (memRefType.hasStaticShape()) {
|
|
||||||
// Convert MemrefType to a multi-dimensional spv.array if size is known.
|
|
||||||
auto elementType = memRefType.getElementType();
|
auto elementType = memRefType.getElementType();
|
||||||
|
if (memRefType.hasStaticShape()) {
|
||||||
|
// Convert to a multi-dimensional spv.array if size is known.
|
||||||
for (auto size : reverse(memRefType.getShape())) {
|
for (auto size : reverse(memRefType.getShape())) {
|
||||||
elementType = spirv::ArrayType::get(elementType, size);
|
elementType = spirv::ArrayType::get(elementType, size);
|
||||||
}
|
}
|
||||||
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
|
|
||||||
// to support other Storage Classes.
|
|
||||||
return spirv::PointerType::get(elementType,
|
return spirv::PointerType::get(elementType,
|
||||||
spirv::StorageClass::StorageBuffer);
|
spirv::StorageClass::StorageBuffer);
|
||||||
|
} else {
|
||||||
|
// Vulkan SPIR-V validation rules require runtime array type to be the
|
||||||
|
// last member of a struct.
|
||||||
|
return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
|
||||||
|
spirv::StorageClass::StorageBuffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Type SPIRVBasicTypeConverter::convertType(Type t) {
|
||||||
|
return basicTypeConversion(t);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Entry Function signature Conversion
|
// Entry Function signature Conversion
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult
|
/// Generates the type of variable given the type of object.
|
||||||
SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
static Type getGlobalVarTypeForEntryFnArg(Type t) {
|
||||||
SignatureConversion &result) {
|
auto convertedType = basicTypeConversion(t);
|
||||||
// Try to convert the given input type.
|
if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
|
||||||
auto convertedType = basicTypeConverter->convertType(type);
|
if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
|
||||||
// TODO(ravishankarm) : Vulkan spec requires these to be a
|
return spirv::PointerType::get(
|
||||||
// spirv::StructType. This is not a SPIR-V requirement, so just making this a
|
spirv::StructType::get(ptrType.getPointeeType()),
|
||||||
// pointer type for now.
|
ptrType.getStorageClass());
|
||||||
if (!convertedType)
|
}
|
||||||
return failure();
|
} else {
|
||||||
// For arguments to entry functions, convert the type into a pointer type if
|
return spirv::PointerType::get(spirv::StructType::get(convertedType),
|
||||||
// it is already not one, unless the original type was an index type.
|
|
||||||
// TODO(ravishankarm): For arguments that are of index type, keep the
|
|
||||||
// arguments as the scalar converted type, i.e. i32. These are still not
|
|
||||||
// handled effectively. These are potentially best handled as specialization
|
|
||||||
// constants.
|
|
||||||
if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
|
|
||||||
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
|
|
||||||
// to support other Storage classes.
|
|
||||||
convertedType = spirv::PointerType::get(convertedType,
|
|
||||||
spirv::StorageClass::StorageBuffer);
|
spirv::StorageClass::StorageBuffer);
|
||||||
}
|
}
|
||||||
|
return convertedType;
|
||||||
// Add the new inputs.
|
|
||||||
result.addInputs(inputNo, convertedType);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult lowerFunctionImpl(
|
Type SPIRVTypeConverter::convertType(Type t) {
|
||||||
FuncOp funcOp, ArrayRef<Value *> operands,
|
return getGlobalVarTypeForEntryFnArg(t);
|
||||||
ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
|
|
||||||
TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
|
|
||||||
auto fnType = funcOp.getType();
|
|
||||||
|
|
||||||
if (fnType.getNumResults()) {
|
|
||||||
return funcOp.emitError("SPIR-V dialect only supports functions with no "
|
|
||||||
"return values right now");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &argType : enumerate(fnType.getInputs())) {
|
/// Computes the replacement value for an argument of an entry function. It
|
||||||
// Get the type of the argument
|
/// allocates a global variable for this argument and adds statements in the
|
||||||
if (failed(typeConverter->convertSignatureArg(
|
/// entry block to get a replacement value within function scope.
|
||||||
argType.index(), argType.value(), signatureConverter))) {
|
static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
|
||||||
return funcOp.emitError("unable to convert argument type ")
|
size_t origArgNum,
|
||||||
<< argType.value() << " to SPIR-V type";
|
Value *origArg) {
|
||||||
|
// Create a global variable for this argument.
|
||||||
|
auto insertionOp = rewriter.getInsertionBlock()->getParent();
|
||||||
|
auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
|
||||||
|
if (!module) {
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
auto funcOp = insertionOp->getParentOfType<FuncOp>();
|
||||||
|
spirv::GlobalVariableOp var;
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&module.getBlock());
|
||||||
|
std::string varName =
|
||||||
|
funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
|
||||||
|
var = rewriter.create<spirv::GlobalVariableOp>(
|
||||||
|
funcOp.getLoc(),
|
||||||
|
rewriter.getTypeAttr(getGlobalVarTypeForEntryFnArg(origArg->getType())),
|
||||||
|
rewriter.getStringAttr(varName), nullptr);
|
||||||
|
var.setAttr(
|
||||||
|
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
|
||||||
|
rewriter.getI32IntegerAttr(0));
|
||||||
|
var.setAttr(
|
||||||
|
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
|
||||||
|
rewriter.getI32IntegerAttr(origArgNum));
|
||||||
|
}
|
||||||
|
// Insert the addressOf and load instructions, to get back the converted value
|
||||||
|
// type.
|
||||||
|
auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
|
||||||
|
auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
|
||||||
|
rewriter.getIntegerType(32),
|
||||||
|
rewriter.getI32IntegerAttr(0));
|
||||||
|
auto accessChain = rewriter.create<spirv::AccessChainOp>(
|
||||||
|
funcOp.getLoc(), addressOf.pointer(), zero.constant());
|
||||||
|
// If the original argument is a tensor/memref type, the value is not
|
||||||
|
// loaded. Instead the pointer value is returned to allow its use in access
|
||||||
|
// chain ops.
|
||||||
|
auto origArgType = origArg->getType();
|
||||||
|
if (origArgType.isa<MemRefType>()) {
|
||||||
|
return accessChain;
|
||||||
|
}
|
||||||
|
return rewriter.create<spirv::LoadOp>(
|
||||||
|
funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
|
||||||
|
/*alignment=*/nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static FuncOp applySignatureConversion(
|
||||||
|
FuncOp funcOp, ConversionPatternRewriter &rewriter,
|
||||||
|
TypeConverter::SignatureConversion &signatureConverter) {
|
||||||
// Create a new function with an updated signature.
|
// Create a new function with an updated signature.
|
||||||
newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
|
||||||
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
|
||||||
newFuncOp.end());
|
newFuncOp.end());
|
||||||
newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
|
newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
|
||||||
|
@ -119,72 +152,113 @@ static LogicalResult lowerFunctionImpl(
|
||||||
// Tell the rewriter to convert the region signature.
|
// Tell the rewriter to convert the region signature.
|
||||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||||
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
|
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
|
||||||
|
return newFuncOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the global variables that need to be specified as interface variable
|
||||||
|
/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
|
||||||
|
LogicalResult getInterfaceVariables(FuncOp funcOp,
|
||||||
|
SmallVectorImpl<Attribute> &interfaceVars) {
|
||||||
|
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
|
||||||
|
if (!module) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
llvm::SetVector<Operation *> interfaceVarSet;
|
||||||
|
for (auto &block : funcOp) {
|
||||||
|
// TODO(ravishankarm) : This should in reality traverse the entry function
|
||||||
|
// call graph and collect all the interfaces. For now, just traverse the
|
||||||
|
// instructions in this function.
|
||||||
|
auto callOps = block.getOps<CallOp>();
|
||||||
|
if (std::distance(callOps.begin(), callOps.end())) {
|
||||||
|
return funcOp.emitError("Collecting interface variables through function "
|
||||||
|
"calls unimplemented");
|
||||||
|
}
|
||||||
|
for (auto op : block.getOps<spirv::AddressOfOp>()) {
|
||||||
|
auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
|
||||||
|
if (var.type().cast<spirv::PointerType>().getStorageClass() ==
|
||||||
|
spirv::StorageClass::StorageBuffer) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
interfaceVarSet.insert(var.getOperation());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto &var : interfaceVarSet) {
|
||||||
|
interfaceVars.push_back(SymbolRefAttr::get(
|
||||||
|
cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
|
||||||
SPIRVTypeConverter *typeConverter,
|
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
FuncOp &newFuncOp) {
|
FuncOp &newFuncOp) {
|
||||||
auto fnType = funcOp.getType();
|
auto fnType = funcOp.getType();
|
||||||
|
if (fnType.getNumResults()) {
|
||||||
|
return funcOp.emitError("SPIR-V lowering only supports functions with no "
|
||||||
|
"return values right now");
|
||||||
|
}
|
||||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||||
return lowerFunctionImpl(funcOp, operands, rewriter,
|
auto basicTypeConverter = typeConverter->getBasicTypeConverter();
|
||||||
typeConverter->getBasicTypeConverter(),
|
for (auto origArgType : enumerate(fnType.getInputs())) {
|
||||||
signatureConverter, newFuncOp);
|
auto convertedType = basicTypeConverter->convertType(origArgType.value());
|
||||||
|
if (!convertedType) {
|
||||||
|
return funcOp.emitError("unable to convert argument of type '")
|
||||||
|
<< convertedType << "'";
|
||||||
|
}
|
||||||
|
signatureConverter.addInputs(origArgType.index(), convertedType);
|
||||||
|
}
|
||||||
|
newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
LogicalResult lowerAsEntryFunction(FuncOp funcOp,
|
||||||
SPIRVTypeConverter *typeConverter,
|
SPIRVTypeConverter *typeConverter,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
FuncOp &newFuncOp) {
|
FuncOp &newFuncOp) {
|
||||||
auto fnType = funcOp.getType();
|
auto fnType = funcOp.getType();
|
||||||
|
if (fnType.getNumResults()) {
|
||||||
|
return funcOp.emitError("SPIR-V lowering only supports functions with no "
|
||||||
|
"return values right now");
|
||||||
|
}
|
||||||
|
// For entry functions need to make the signature void(void). Compute the
|
||||||
|
// replacement value for all arguments and replace all uses.
|
||||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||||
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
|
{
|
||||||
signatureConverter, newFuncOp))) {
|
OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&funcOp.front());
|
||||||
|
for (auto origArg : enumerate(funcOp.getArguments())) {
|
||||||
|
auto replacement = createAndLoadGlobalVarForEntryFnArg(
|
||||||
|
rewriter, origArg.index(), origArg.value());
|
||||||
|
rewriter.replaceUsesOfBlockArgument(origArg.value(), replacement);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
|
||||||
|
// Add the spv.EntryPointOp after collecting all the interface variables
|
||||||
|
// needed.
|
||||||
|
SmallVector<Attribute, 1> interfaceVars;
|
||||||
|
if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
// Create spv.globalVariable ops for each of the arguments. These need to be
|
builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
|
||||||
// bound by the runtime. For now use descriptor_set 0, and arg number as the
|
spirv::ExecutionModel::GLCompute,
|
||||||
// binding number.
|
newFuncOp, interfaceVars);
|
||||||
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
|
// Specify the spv.ExecutionModeOp.
|
||||||
if (!module) {
|
|
||||||
return funcOp.emitError("expected op to be within a spv.module");
|
/// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
|
||||||
}
|
/// LocalSize execution mode or an object decorated with the WorkgroupSize
|
||||||
auto ip = rewriter.saveInsertionPoint();
|
/// decoration must be specified." Better approach is to use the
|
||||||
rewriter.setInsertionPointToStart(&module.getBlock());
|
/// WorkgroupSize GlobalVariable with initializer being a specialization
|
||||||
SmallVector<Attribute, 4> interface;
|
/// constant. But current support for specialization constant does not allow
|
||||||
for (auto &convertedArgType :
|
/// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
|
||||||
llvm::enumerate(signatureConverter.getConvertedTypes())) {
|
/// 1} for now. To be fixed ASAP.
|
||||||
// TODO(ravishankarm) : The arguments to the converted function are either
|
builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
|
||||||
// spirv::PointerType or i32 type, the latter due to conversion of index
|
spirv::ExecutionMode::LocalSize,
|
||||||
// type to i32. Eventually entry function should be of signature
|
ArrayRef<int32_t>{1, 1, 1});
|
||||||
// void(void). Arguments converted to spirv::PointerType, will be made
|
|
||||||
// variables and those converted to i32 will be made specialization
|
|
||||||
// constants. Latter is not implemented.
|
|
||||||
if (!convertedArgType.value().isa<spirv::PointerType>()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
std::string varName = funcOp.getName().str() + "_arg_" +
|
|
||||||
std::to_string(convertedArgType.index());
|
|
||||||
auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
|
|
||||||
funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
|
|
||||||
rewriter.getStringAttr(varName), nullptr);
|
|
||||||
variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
|
|
||||||
variableOp.setAttr("binding",
|
|
||||||
rewriter.getI32IntegerAttr(convertedArgType.index()));
|
|
||||||
interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
|
|
||||||
}
|
|
||||||
// Create an entry point instruction for this function.
|
|
||||||
// TODO(ravishankarm) : Add execution mode for the entry function
|
|
||||||
rewriter.setInsertionPoint(&(module.getBlock().back()));
|
|
||||||
rewriter.create<spirv::EntryPointOp>(
|
|
||||||
funcOp.getLoc(),
|
|
||||||
rewriter.getI32IntegerAttr(
|
|
||||||
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
|
|
||||||
rewriter.getSymbolRefAttr(newFuncOp.getName()),
|
|
||||||
rewriter.getArrayAttr(interface));
|
|
||||||
rewriter.restoreInsertionPoint(ip);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -16,13 +16,52 @@ module attributes {gpu.container_module} {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: spv.module "Logical" "GLSL450"
|
// CHECK-LABEL: spv.module "Logical" "GLSL450"
|
||||||
// CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
|
|
||||||
// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
|
|
||||||
// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
|
|
||||||
// CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
|
|
||||||
module @kernels attributes {gpu.kernel_module} {
|
module @kernels attributes {gpu.kernel_module} {
|
||||||
|
// CHECK-DAG: spv.globalVariable [[WORKGROUPSIZEVAR:@.*]] built_in("WorkgroupSize") : !spv.ptr<vector<3xi32>, Input>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
|
||||||
|
// CHECK: func [[FN:@.*]]()
|
||||||
func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
|
func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
|
||||||
attributes {gpu.kernel} {
|
attributes {gpu.kernel} {
|
||||||
|
// CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
|
||||||
|
// CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
|
||||||
|
// CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
|
||||||
|
// CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
|
||||||
|
// CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
|
||||||
|
// CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
|
||||||
|
// CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
|
||||||
|
// CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
|
||||||
|
// CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
|
||||||
|
// CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
|
||||||
|
// CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
|
||||||
|
// CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
|
||||||
|
// CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
|
||||||
|
// CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
|
||||||
|
// CHECK: [[ARG5:%.*]] = spv.Load "StorageBuffer" [[ARG5PTR]]
|
||||||
|
// CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
|
||||||
|
// CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
|
||||||
|
// CHECK: [[ARG6:%.*]] = spv.Load "StorageBuffer" [[ARG6PTR]]
|
||||||
|
// CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
|
||||||
|
// CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
|
||||||
|
// CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
|
||||||
|
// CHECK: [[ADDRESSLOCALINVOCATIONID:%.*]] = spv._address_of [[LOCALINVOCATIONIDVAR]]
|
||||||
|
// CHECK: [[LOCALINVOCATIONID:%.*]] = spv.Load "Input" [[ADDRESSLOCALINVOCATIONID]]
|
||||||
|
// CHECK: [[LOCALINVOCATIONIDX:%.*]] = spv.CompositeExtract [[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
|
||||||
%0 = "gpu.block_id"() {dimension = "x"} : () -> index
|
%0 = "gpu.block_id"() {dimension = "x"} : () -> index
|
||||||
%1 = "gpu.block_id"() {dimension = "y"} : () -> index
|
%1 = "gpu.block_id"() {dimension = "y"} : () -> index
|
||||||
%2 = "gpu.block_id"() {dimension = "z"} : () -> index
|
%2 = "gpu.block_id"() {dimension = "z"} : () -> index
|
||||||
|
@ -35,9 +74,9 @@ module attributes {gpu.container_module} {
|
||||||
%9 = "gpu.block_dim"() {dimension = "x"} : () -> index
|
%9 = "gpu.block_dim"() {dimension = "x"} : () -> index
|
||||||
%10 = "gpu.block_dim"() {dimension = "y"} : () -> index
|
%10 = "gpu.block_dim"() {dimension = "y"} : () -> index
|
||||||
%11 = "gpu.block_dim"() {dimension = "z"} : () -> index
|
%11 = "gpu.block_dim"() {dimension = "z"} : () -> index
|
||||||
// CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}}
|
// CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], [[WORKGROUPIDX]]
|
||||||
%12 = addi %arg3, %0 : index
|
%12 = addi %arg3, %0 : index
|
||||||
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}}
|
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
|
||||||
%13 = addi %arg4, %3 : index
|
%13 = addi %arg4, %3 : index
|
||||||
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
|
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
|
||||||
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
|
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
|
||||||
|
|
|
@ -2,15 +2,23 @@
|
||||||
|
|
||||||
module attributes {gpu.container_module} {
|
module attributes {gpu.container_module} {
|
||||||
|
|
||||||
// CHECK: spv.module "Logical" "GLSL450" {
|
|
||||||
// CHECK-NEXT: spv.globalVariable [[VAR1:@.*]] bind(0, 0) : !spv.ptr<f32, StorageBuffer>
|
|
||||||
// CHECK-NEXT: spv.globalVariable [[VAR2:@.*]] bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
|
|
||||||
// CHECK-NEXT: func @kernel_1
|
|
||||||
// CHECK-NEXT: spv.Return
|
|
||||||
// CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]
|
|
||||||
module @kernels attributes {gpu.kernel_module} {
|
module @kernels attributes {gpu.kernel_module} {
|
||||||
|
// CHECK: spv.module "Logical" "GLSL450" {
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32>, StorageBuffer>
|
||||||
|
// CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
|
||||||
|
// CHECK: func [[FN:@.*]]()
|
||||||
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
|
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
|
||||||
attributes { gpu.kernel } {
|
attributes { gpu.kernel } {
|
||||||
|
// CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
|
||||||
|
// CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
|
||||||
|
// CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
|
||||||
|
// CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
|
||||||
|
// CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
|
||||||
|
// CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
|
||||||
|
// CHECK-NEXT: spv.Return
|
||||||
|
// CHECK: spv.EntryPoint "GLCompute" [[FN]]
|
||||||
|
// CHECK: spv.ExecutionMode [[FN]] "LocalSize"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,5 +31,4 @@ module attributes {gpu.container_module} {
|
||||||
: (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
|
: (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue