From cc980aa41651c2cbfcbd9048fb0788f4aa9ae475 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Thu, 15 Aug 2019 10:54:22 -0700 Subject: [PATCH] Simplify the classes that support SPIR-V conversion. Modify the Type converters to have a SPIRVBasicTypeConverter which only handles conversion from standard types to SPIRV types. Rename SPIRVEntryFnConverter to SPIRVTypeConverter. This contains the SPIRVBasicTypeConverter within it. Remove SPIRVFnLowering class and have separate utility methods to lower a function as entry function or a non-entry function. The current setup could end with diamond inheritence that is not very friendly to use. For example, you could define the following Op conversion methods that lower from a dialect "Foo" which resuls in diamond inheritance. template class FooDialect : public SPIRVOpLowering {...}; class FooFnLowering : public FooDialect, SPIRVFnLowering {...}; PiperOrigin-RevId: 263597101 --- .../StandardToSPIRV/ConvertStandardToSPIRV.h | 59 ++++++++++--------- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 21 ++++--- .../ConvertStandardToSPIRV.cpp | 45 +++++++------- 3 files changed, 66 insertions(+), 59 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h index 21c2842cf130..adfd83b3f64f 100644 --- a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h +++ b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h @@ -33,12 +33,12 @@ class SPIRVDialect; } /// Type conversion from Standard Types to SPIR-V Types. -class SPIRVTypeConverter : public TypeConverter { +class SPIRVBasicTypeConverter : public TypeConverter { public: - explicit SPIRVTypeConverter(MLIRContext *context); + explicit SPIRVBasicTypeConverter(MLIRContext *context); /// Converts types to SPIR-V supported types. - Type convertType(Type t) override; + virtual Type convertType(Type t); protected: spirv::SPIRVDialect *spirvDialect; @@ -47,51 +47,54 @@ protected: /// Converts a function type according to the requirements of a SPIR-V entry /// function. The arguments need to be converted to spv.Variables of spv.ptr /// types so that they could be bound by the runtime. -class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter { +class SPIRVTypeConverter final : public TypeConverter { public: - using SPIRVTypeConverter::SPIRVTypeConverter; + explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter) + : basicTypeConverter(basicTypeConverter) {} + + /// Convert types to SPIR-V types using the basic type converter. + Type convertType(Type t) override { + return basicTypeConverter->convertType(t); + } /// Method to convert argument of a function. The `type` is converted to /// spv.ptr. // TODO(ravishankarm) : Support other storage classes. LogicalResult convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) override; + + /// Get the basic type converter. + SPIRVBasicTypeConverter *getBasicTypeConverter() const { + return basicTypeConverter; + } + +private: + SPIRVBasicTypeConverter *basicTypeConverter; }; /// Base class to define a conversion pattern to translate Ops into SPIR-V. template class SPIRVOpLowering : public ConversionPattern { public: - SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter, - SPIRVEntryFnTypeConverter &entryFnConverter) + SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter) : ConversionPattern(OpTy::getOperationName(), 1, context), - typeConverter(typeConverter), entryFnConverter(entryFnConverter) {} + typeConverter(typeConverter) {} protected: // Type lowering class. SPIRVTypeConverter &typeConverter; - - // Entry function signature converter. - SPIRVEntryFnTypeConverter &entryFnConverter; }; -/// Base Class for legalize a FuncOp within a spv.module. This class can be -/// extended to implement a ConversionPattern to lower a FuncOp. It provides -/// hooks to legalize a FuncOp as a simple function, or as an entry function. -class SPIRVFnLowering : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; +/// Method to legalize a function as a non-entry function. +LogicalResult lowerFunction(FuncOp funcOp, ArrayRef operands, + SPIRVTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FuncOp &newFuncOp); -protected: - /// Method to legalize the function as a non-entry function. - LogicalResult lowerFunction(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter, - FuncOp &newFuncOp) const; - - /// Method to legalize the function as an entry function. - LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter, - FuncOp &newFuncOp) const; -}; +/// Method to legalize a function as an entry function. +LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, + SPIRVTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FuncOp &newFuncOp); /// Appends to a pattern list additional patterns for translating StandardOps to /// SPIR-V ops. diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index c36aee5d62bc..ff6af83b9bea 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -31,9 +31,9 @@ namespace { /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the /// attribute gpu.kernel) within a spv.module. -class KernelFnConversion final : public SPIRVFnLowering { +class KernelFnConversion final : public SPIRVOpLowering { public: - using SPIRVFnLowering::SPIRVFnLowering; + using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -47,12 +47,14 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef operands, auto funcOp = cast(op); FuncOp newFuncOp; if (!gpu::GPUDialect::isKernel(funcOp)) { - return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp)) + return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter, + newFuncOp)) ? matchSuccess() : matchFailure(); } - if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) { + if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter, + newFuncOp))) { return matchFailure(); } newFuncOp.getOperation()->removeAttr(Identifier::get( @@ -101,16 +103,17 @@ void GPUToSPIRVPass::runOnModule() { } /// Dialect conversion to lower the functions with the spirv::ModuleOps. - SPIRVTypeConverter typeConverter(context); - SPIRVEntryFnTypeConverter entryFnConverter(context); + SPIRVBasicTypeConverter basicTypeConverter(context); + SPIRVTypeConverter typeConverter(&basicTypeConverter); OwningRewritePatternList patterns; - patterns.insert(context, typeConverter, entryFnConverter); + patterns.insert(context, typeConverter); populateStandardToSPIRVPatterns(context, patterns); ConversionTarget target(*context); target.addLegalDialect(); - target.addDynamicallyLegalOp( - [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); }); + target.addDynamicallyLegalOp([&](FuncOp Op) { + return basicTypeConverter.isSignatureLegal(Op.getType()); + }); if (failed(applyFullConversion(spirvModules, target, patterns, &typeConverter))) { diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 067f2aeda06d..53a40dfa365e 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -30,10 +30,10 @@ using namespace mlir; // Type Conversion //===----------------------------------------------------------------------===// -SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context) +SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context) : spirvDialect(context->getRegisteredDialect()) {} -Type SPIRVTypeConverter::convertType(Type t) { +Type SPIRVBasicTypeConverter::convertType(Type t) { // Check if the type is SPIR-V supported. If so return the type. if (spirvDialect->isValidSPIRVType(t)) { return t; @@ -58,10 +58,10 @@ Type SPIRVTypeConverter::convertType(Type t) { //===----------------------------------------------------------------------===// LogicalResult -SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type, - SignatureConversion &result) { +SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type, + SignatureConversion &result) { // Try to convert the given input type. - auto convertedType = convertType(type); + auto convertedType = basicTypeConverter->convertType(type); // TODO(ravishankarm) : Vulkan spec requires these to be a // spirv::StructType. This is not a SPIR-V requirement, so just making this a // pointer type for now. @@ -81,12 +81,10 @@ SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type, return success(); } -template -static LogicalResult -lowerFunctionImpl(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter, Converter &typeConverter, - TypeConverter::SignatureConversion &signatureConverter, - FuncOp &newFuncOp) { +static LogicalResult lowerFunctionImpl( + FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, + TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) { auto fnType = funcOp.getType(); if (fnType.getNumResults()) { @@ -96,7 +94,7 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef operands, for (auto &argType : enumerate(fnType.getInputs())) { // Get the type of the argument - if (failed(typeConverter.convertSignatureArg( + if (failed(typeConverter->convertSignatureArg( argType.index(), argType.value(), signatureConverter))) { return funcOp.emitError("unable to convert argument type ") << argType.value() << " to SPIR-V type"; @@ -116,23 +114,25 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef operands, return success(); } -LogicalResult -SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter, - FuncOp &newFuncOp) const { +namespace mlir { +LogicalResult lowerFunction(FuncOp funcOp, ArrayRef operands, + SPIRVTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FuncOp &newFuncOp) { auto fnType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); - return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter, + return lowerFunctionImpl(funcOp, operands, rewriter, + typeConverter->getBasicTypeConverter(), signatureConverter, newFuncOp); } -LogicalResult -SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter, - FuncOp &newFuncOp) const { +LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, + SPIRVTypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + FuncOp &newFuncOp) { auto fnType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); - if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter, + if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter, signatureConverter, newFuncOp))) { return failure(); } @@ -167,6 +167,7 @@ SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, builder.getSymbolRefAttr(newFuncOp.getName()), interface); return success(); } +} // namespace mlir //===----------------------------------------------------------------------===// // Operation conversion