llvm-project/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp

266 lines
11 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//===- LowerABIAttributesPass.cpp - Lower SPIR-V ABI attributes -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to lower attributes that specify the shader ABI
// for the functions in the generated SPIR-V module.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
/// Creates a global variable for an entry point argument based on its ABI
/// info.
///
/// The variable is inserted right before `funcOp` and is named
/// `<function name>_arg_<argIndex>`. Its descriptor set and binding are taken
/// from `abiInfo`. The insertion point of `builder` is restored on return.
///
/// Returns a null op when the variable cannot be created:
///   - `funcOp` is not nested in a spv.module;
///   - the argument type is not a SPIR-V type;
///   - the argument is a scalar/vector but `abiInfo` carries no storage class;
///   - the argument is neither a scalar/vector nor a
///     !spv.ptr<!spv.struct<...>>;
///   - decorating the struct with layout information does not yield a struct.
static spirv::GlobalVariableOp
createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
                                     unsigned argIndex,
                                     spirv::InterfaceVarABIAttr abiInfo) {
  auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
  if (!spirvModule)
    return nullptr;

  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
  builder.setInsertionPoint(funcOp.getOperation());
  std::string varName =
      funcOp.getName().str() + "_arg_" + std::to_string(argIndex);

  // Get the type of variable. If this is a scalar/vector type and has an ABI
  // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
  // it must already be a !spv.ptr<!spv.struct<...>>.
  Type varType = funcOp.getType().getInput(argIndex);
  // Use checked casts throughout: an unsupported argument type is reported to
  // the caller as a failure (null op) instead of tripping an assertion.
  auto spirvArgType = varType.dyn_cast<spirv::SPIRVType>();
  if (!spirvArgType)
    return nullptr;
  if (spirvArgType.isScalarOrVector()) {
    auto storageClass = abiInfo.getStorageClass();
    if (!storageClass)
      return nullptr;
    varType =
        spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
  }

  auto varPtrType = varType.dyn_cast<spirv::PointerType>();
  if (!varPtrType)
    return nullptr;
  auto varPointeeType =
      varPtrType.getPointeeType().dyn_cast<spirv::StructType>();
  if (!varPointeeType)
    return nullptr;

  // Set the offset information. Guard against a null or non-struct result.
  varPointeeType = VulkanLayoutUtils::decorateType(varPointeeType)
                       .dyn_cast_or_null<spirv::StructType>();
  if (!varPointeeType)
    return nullptr;

  varType =
      spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());

  return builder.create<spirv::GlobalVariableOp>(
      funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
      abiInfo.getBinding());
}
/// Collects the global variables that must be listed as interface variables
/// on the spv.EntryPointOp of `funcOp`, appending one symbol reference per
/// variable to `interfaceVars`. Only the body of the entry function itself is
/// inspected. Returns failure if `funcOp` is not nested in a spv.module.
static LogicalResult
getInterfaceVariables(spirv::FuncOp funcOp,
                      SmallVectorImpl<Attribute> &interfaceVars) {
  auto parentModule = funcOp.getParentOfType<spirv::ModuleOp>();
  if (!parentModule)
    return failure();

  // Set semantics: a variable referenced several times is recorded once.
  llvm::SetVector<Operation *> referencedVars;

  // TODO: This should in reality traverse the entry function
  // call graph and collect all the interfaces. For now, just traverse the
  // instructions in this function.
  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
    auto globalVar = parentModule.lookupSymbol<spirv::GlobalVariableOp>(
        addressOfOp.variable());
    // TODO: Per SPIR-V spec: "Before version 1.4, the interfaces
    // storage classes are limited to the Input and Output storage classes.
    // Starting with version 1.4, the interfaces storage classes are all
    // storage classes used in declaring all global variables referenced by the
    // entry points call tree." We should consider the target environment here.
    auto storageClass =
        globalVar.type().cast<spirv::PointerType>().getStorageClass();
    if (storageClass == spirv::StorageClass::Input ||
        storageClass == spirv::StorageClass::Output)
      referencedVars.insert(globalVar.getOperation());
  });

  // Turn every collected variable into a reference to its symbol name.
  MLIRContext *context = funcOp.getContext();
  for (Operation *varOp : referencedVars)
    interfaceVars.push_back(SymbolRefAttr::get(
        cast<spirv::GlobalVariableOp>(varOp).sym_name(), context));
  return success();
}
/// Lowers the entry point ABI attribute attached to `funcOp`.
///
/// Creates an spv.EntryPointOp (GLCompute execution model) listing the
/// interface variables used by the function, and an spv.ExecutionModeOp
/// (LocalSize) populated from the attribute's local size. Both ops are
/// inserted before the terminator of the enclosing spv.module, and the entry
/// point ABI attribute is removed from `funcOp` afterwards. The insertion
/// point of `builder` is restored on return.
///
/// Returns failure if `funcOp` has no entry point ABI attribute, is not nested
/// in a spv.module, or if its interface variables cannot be collected.
static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
                                            OpBuilder &builder) {
  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
  auto entryPointAttr =
      funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
  if (!entryPointAttr)
    return failure();

  // The module body is dereferenced below, so bail out early when the
  // function is not nested in a spv.module instead of using a null op.
  auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
  if (!spirvModule)
    return failure();

  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
  builder.setInsertionPoint(spirvModule.body().front().getTerminator());

  // Adds the spv.EntryPointOp after collecting all the interface variables
  // needed.
  SmallVector<Attribute, 1> interfaceVars;
  if (failed(getInterfaceVariables(funcOp, interfaceVars)))
    return failure();
  builder.create<spirv::EntryPointOp>(
      funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);

  // Specifies the spv.ExecutionModeOp.
  auto localSizeAttr = entryPointAttr.local_size();
  SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
  builder.create<spirv::ExecutionModeOp>(
      funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);

  // The attribute has been fully lowered; drop it from the function.
  funcOp.removeAttr(entryPointAttrName);
  return success();
}
namespace {

/// Conversion pattern that rewrites a function signature according to the
/// interface variable ABI attributes on its arguments.
///
/// For every argument carrying such an attribute, a global variable is
/// created and all uses of the argument are redirected to that variable. This
/// is necessary because Vulkan requires all shader entry points to be of
/// void(void) type.
class ProcessInterfaceVarABI final : public SPIRVOpLowering<spirv::FuncOp> {
public:
  // Inherit the constructors of the base lowering class.
  using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Pass that lowers the ABI information specified as attributes.
class LowerABIAttributesPass final
    : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
private:
  void runOnOperation() override;
};

} // namespace
/// Replaces every argument of an entry point function with a value derived
/// from a newly created global variable and rewrites the function type to
/// take the converted inputs and return nothing. Fails for functions without
/// an entry point ABI attribute, for arguments without interface variable ABI
/// info, and when the global variable for an argument cannot be created.
LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
    spirv::FuncOp funcOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // TODO: Non-entry point functions are not handled.
  auto entryPointAttr = funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
      spirv::getEntryPointABIAttrName());
  if (!entryPointAttr)
    return failure();

  auto funcType = funcOp.getType();
  auto loc = funcOp.getLoc();
  const unsigned numInputs = funcType.getNumInputs();
  TypeConverter::SignatureConversion signatureConverter(numInputs);

  auto abiAttrName = spirv::getInterfaceVarABIAttrName();
  for (unsigned argIdx = 0; argIdx < numInputs; ++argIdx) {
    // TODO: For non-entry point functions, it should be legal
    // to pass around scalar/vector values and return a scalar/vector. For now
    // non-entry point functions are not handled in this ABI lowering and will
    // produce an error.
    auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
        argIdx, abiAttrName);
    if (!abiInfo)
      return failure();

    spirv::GlobalVariableOp globalVar =
        createGlobalVarForEntryPointArgument(rewriter, funcOp, argIdx, abiInfo);
    if (!globalVar)
      return failure();

    // The ops standing in for the argument go at the top of the entry block.
    OpBuilder::InsertionGuard entryBlockGuard(rewriter);
    rewriter.setInsertionPointToStart(&funcOp.front());

    // Insert spirv::AddressOf and spirv::AccessChain operations.
    Value replacement = rewriter.create<spirv::AddressOfOp>(loc, globalVar);

    // Check if the arg is a scalar or vector type. In that case, the value
    // needs to be loaded into registers.
    // TODO: This is loading value of the scalar into registers
    // at the start of the function. It is probably better to do the load just
    // before the use. There might be multiple loads and currently there is no
    // easy way to replace all uses with a sequence of operations.
    Type argType = funcType.getInput(argIdx);
    if (argType.cast<spirv::SPIRVType>().isScalarOrVector()) {
      auto indexType = SPIRVTypeConverter::getIndexType(funcOp.getContext());
      auto zeroIndex = spirv::ConstantOp::getZero(indexType, loc, rewriter);
      auto elementPtr = rewriter.create<spirv::AccessChainOp>(
          loc, replacement, zeroIndex.constant());
      replacement = rewriter.create<spirv::LoadOp>(loc, elementPtr);
    }

    signatureConverter.remapInput(argIdx, replacement);
  }

  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return failure();

  // Update the function in place to carry the converted signature.
  rewriter.updateRootInPlace(funcOp, [&] {
    funcOp.setType(rewriter.getFunctionType(
        signatureConverter.getConvertedTypes(), llvm::None));
  });
  return success();
}
/// Runs the ABI attribute lowering on the spv.module in two steps:
///   1. A partial dialect conversion applies ProcessInterfaceVarABI so that
///      no function keeps an interface variable ABI attribute on its
///      arguments.
///   2. Every function that carries an entry point ABI attribute has that
///      attribute lowered by lowerEntryPointABIAttr.
/// The pass signals failure if either step fails.
void LowerABIAttributesPass::runOnOperation() {
  // Uses the signature conversion methodology of the dialect conversion
  // framework to implement the conversion.
  spirv::ModuleOp module = getOperation();
  MLIRContext *context = &getContext();

  spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));

  SPIRVTypeConverter typeConverter(targetEnv);
  OwningRewritePatternList patterns;
  patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);

  ConversionTarget target(*context);
  // "Legal" function ops should have no interface variable ABI attributes.
  target.addDynamicallyLegalOp<spirv::FuncOp>([&](spirv::FuncOp op) {
    StringRef attrName = spirv::getInterfaceVarABIAttrName();
    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
      if (op.getArgAttr(i, attrName))
        return false;
    return true;
  });
  // All other SPIR-V ops are legal.
  target.markUnknownOpDynamicallyLegal([](Operation *op) {
    // Ops of an unregistered dialect have no Dialect object; treat them as
    // not belonging to the SPIR-V dialect instead of dereferencing null.
    auto *dialect = op->getDialect();
    return dialect && dialect->getNamespace() ==
                          spirv::SPIRVDialect::getDialectNamespace();
  });
  if (failed(applyPartialConversion(module, target, patterns)))
    return signalPassFailure();

  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
  // attributes. The functions are collected first and lowered afterwards.
  OpBuilder builder(context);
  SmallVector<spirv::FuncOp, 1> entryPointFns;
  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
  module.walk([&](spirv::FuncOp funcOp) {
    if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName))
      entryPointFns.push_back(funcOp);
  });
  for (auto fn : entryPointFns) {
    if (failed(lowerEntryPointABIAttr(fn, builder)))
      return signalPassFailure();
  }
}
/// Creates a new instance of the pass that lowers SPIR-V ABI attributes. The
/// returned pass operates on spv.module ops.
std::unique_ptr<OperationPass<spirv::ModuleOp>>
mlir::spirv::createLowerABIAttributesPass() {
  std::unique_ptr<OperationPass<spirv::ModuleOp>> pass =
      std::make_unique<LowerABIAttributesPass>();
  return pass;
}