forked from OSchip/llvm-project
Simplify the classes that support SPIR-V conversion.
Modify the Type converters to have a SPIRVBasicTypeConverter which only handles conversion from standard types to SPIRV types. Rename SPIRVEntryFnConverter to SPIRVTypeConverter. This contains the SPIRVBasicTypeConverter within it. Remove SPIRVFnLowering class and have separate utility methods to lower a function as entry function or a non-entry function. The current setup could end with diamond inheritence that is not very friendly to use. For example, you could define the following Op conversion methods that lower from a dialect "Foo" which resuls in diamond inheritance. template<typename OpTy> class FooDialect : public SPIRVOpLowering<OpTy> {...}; class FooFnLowering : public FooDialect, SPIRVFnLowering {...}; PiperOrigin-RevId: 263597101
This commit is contained in:
parent
d71915420b
commit
cc980aa416
|
@ -33,12 +33,12 @@ class SPIRVDialect;
|
|||
}
|
||||
|
||||
/// Type conversion from Standard Types to SPIR-V Types.
|
||||
class SPIRVTypeConverter : public TypeConverter {
|
||||
class SPIRVBasicTypeConverter : public TypeConverter {
|
||||
public:
|
||||
explicit SPIRVTypeConverter(MLIRContext *context);
|
||||
explicit SPIRVBasicTypeConverter(MLIRContext *context);
|
||||
|
||||
/// Converts types to SPIR-V supported types.
|
||||
Type convertType(Type t) override;
|
||||
virtual Type convertType(Type t);
|
||||
|
||||
protected:
|
||||
spirv::SPIRVDialect *spirvDialect;
|
||||
|
@ -47,51 +47,54 @@ protected:
|
|||
/// Converts a function type according to the requirements of a SPIR-V entry
|
||||
/// function. The arguments need to be converted to spv.Variables of spv.ptr
|
||||
/// types so that they could be bound by the runtime.
|
||||
class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
|
||||
class SPIRVTypeConverter final : public TypeConverter {
|
||||
public:
|
||||
using SPIRVTypeConverter::SPIRVTypeConverter;
|
||||
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
|
||||
: basicTypeConverter(basicTypeConverter) {}
|
||||
|
||||
/// Convert types to SPIR-V types using the basic type converter.
|
||||
Type convertType(Type t) override {
|
||||
return basicTypeConverter->convertType(t);
|
||||
}
|
||||
|
||||
/// Method to convert argument of a function. The `type` is converted to
|
||||
/// spv.ptr<type, Uniform>.
|
||||
// TODO(ravishankarm) : Support other storage classes.
|
||||
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
|
||||
SignatureConversion &result) override;
|
||||
|
||||
/// Get the basic type converter.
|
||||
SPIRVBasicTypeConverter *getBasicTypeConverter() const {
|
||||
return basicTypeConverter;
|
||||
}
|
||||
|
||||
private:
|
||||
SPIRVBasicTypeConverter *basicTypeConverter;
|
||||
};
|
||||
|
||||
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
|
||||
template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
|
||||
public:
|
||||
SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
|
||||
SPIRVEntryFnTypeConverter &entryFnConverter)
|
||||
SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
|
||||
: ConversionPattern(OpTy::getOperationName(), 1, context),
|
||||
typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
|
||||
typeConverter(typeConverter) {}
|
||||
|
||||
protected:
|
||||
// Type lowering class.
|
||||
SPIRVTypeConverter &typeConverter;
|
||||
|
||||
// Entry function signature converter.
|
||||
SPIRVEntryFnTypeConverter &entryFnConverter;
|
||||
};
|
||||
|
||||
/// Base Class for legalize a FuncOp within a spv.module. This class can be
|
||||
/// extended to implement a ConversionPattern to lower a FuncOp. It provides
|
||||
/// hooks to legalize a FuncOp as a simple function, or as an entry function.
|
||||
class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
|
||||
/// Method to legalize a function as a non-entry function.
|
||||
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
SPIRVTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp);
|
||||
|
||||
protected:
|
||||
/// Method to legalize the function as a non-entry function.
|
||||
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) const;
|
||||
|
||||
/// Method to legalize the function as an entry function.
|
||||
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) const;
|
||||
};
|
||||
/// Method to legalize a function as an entry function.
|
||||
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
SPIRVTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp);
|
||||
|
||||
/// Appends to a pattern list additional patterns for translating StandardOps to
|
||||
/// SPIR-V ops.
|
||||
|
|
|
@ -31,9 +31,9 @@ namespace {
|
|||
|
||||
/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
|
||||
/// attribute gpu.kernel) within a spv.module.
|
||||
class KernelFnConversion final : public SPIRVFnLowering {
|
||||
class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
|
||||
public:
|
||||
using SPIRVFnLowering::SPIRVFnLowering;
|
||||
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
|
@ -47,12 +47,14 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
|||
auto funcOp = cast<FuncOp>(op);
|
||||
FuncOp newFuncOp;
|
||||
if (!gpu::GPUDialect::isKernel(funcOp)) {
|
||||
return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
|
||||
return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
|
||||
newFuncOp))
|
||||
? matchSuccess()
|
||||
: matchFailure();
|
||||
}
|
||||
|
||||
if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
|
||||
if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
|
||||
newFuncOp))) {
|
||||
return matchFailure();
|
||||
}
|
||||
newFuncOp.getOperation()->removeAttr(Identifier::get(
|
||||
|
@ -101,16 +103,17 @@ void GPUToSPIRVPass::runOnModule() {
|
|||
}
|
||||
|
||||
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
|
||||
SPIRVTypeConverter typeConverter(context);
|
||||
SPIRVEntryFnTypeConverter entryFnConverter(context);
|
||||
SPIRVBasicTypeConverter basicTypeConverter(context);
|
||||
SPIRVTypeConverter typeConverter(&basicTypeConverter);
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
|
||||
patterns.insert<KernelFnConversion>(context, typeConverter);
|
||||
populateStandardToSPIRVPatterns(context, patterns);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<spirv::SPIRVDialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
|
||||
return basicTypeConverter.isSignatureLegal(Op.getType());
|
||||
});
|
||||
|
||||
if (failed(applyFullConversion(spirvModules, target, patterns,
|
||||
&typeConverter))) {
|
||||
|
|
|
@ -30,10 +30,10 @@ using namespace mlir;
|
|||
// Type Conversion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
|
||||
SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
|
||||
: spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
|
||||
|
||||
Type SPIRVTypeConverter::convertType(Type t) {
|
||||
Type SPIRVBasicTypeConverter::convertType(Type t) {
|
||||
// Check if the type is SPIR-V supported. If so return the type.
|
||||
if (spirvDialect->isValidSPIRVType(t)) {
|
||||
return t;
|
||||
|
@ -58,10 +58,10 @@ Type SPIRVTypeConverter::convertType(Type t) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
||||
SignatureConversion &result) {
|
||||
SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
||||
SignatureConversion &result) {
|
||||
// Try to convert the given input type.
|
||||
auto convertedType = convertType(type);
|
||||
auto convertedType = basicTypeConverter->convertType(type);
|
||||
// TODO(ravishankarm) : Vulkan spec requires these to be a
|
||||
// spirv::StructType. This is not a SPIR-V requirement, so just making this a
|
||||
// pointer type for now.
|
||||
|
@ -81,12 +81,10 @@ SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename Converter>
|
||||
static LogicalResult
|
||||
lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter, Converter &typeConverter,
|
||||
TypeConverter::SignatureConversion &signatureConverter,
|
||||
FuncOp &newFuncOp) {
|
||||
static LogicalResult lowerFunctionImpl(
|
||||
FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
|
||||
TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
|
||||
auto fnType = funcOp.getType();
|
||||
|
||||
if (fnType.getNumResults()) {
|
||||
|
@ -96,7 +94,7 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
|
|||
|
||||
for (auto &argType : enumerate(fnType.getInputs())) {
|
||||
// Get the type of the argument
|
||||
if (failed(typeConverter.convertSignatureArg(
|
||||
if (failed(typeConverter->convertSignatureArg(
|
||||
argType.index(), argType.value(), signatureConverter))) {
|
||||
return funcOp.emitError("unable to convert argument type ")
|
||||
<< argType.value() << " to SPIR-V type";
|
||||
|
@ -116,23 +114,25 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) const {
|
||||
namespace mlir {
|
||||
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
SPIRVTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) {
|
||||
auto fnType = funcOp.getType();
|
||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||
return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
|
||||
return lowerFunctionImpl(funcOp, operands, rewriter,
|
||||
typeConverter->getBasicTypeConverter(),
|
||||
signatureConverter, newFuncOp);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) const {
|
||||
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
SPIRVTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
FuncOp &newFuncOp) {
|
||||
auto fnType = funcOp.getType();
|
||||
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
|
||||
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
|
||||
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
|
||||
signatureConverter, newFuncOp))) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -167,6 +167,7 @@ SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
|
|||
builder.getSymbolRefAttr(newFuncOp.getName()), interface);
|
||||
return success();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation conversion
|
||||
|
|
Loading…
Reference in New Issue