Simplify the classes that support SPIR-V conversion.

Modify the Type converters to have a SPIRVBasicTypeConverter which
only handles conversion from standard types to SPIRV types. Rename
SPIRVEntryFnConverter to SPIRVTypeConverter. This contains the
SPIRVBasicTypeConverter within it.

Remove SPIRVFnLowering class and have separate utility methods to
lower a function as entry function or a non-entry function. The
current setup could end with diamond inheritence that is not very
friendly to use.  For example, you could define the following Op
conversion methods that lower from a dialect "Foo" which resuls in
diamond inheritance.

template<typename OpTy>
class FooDialect : public SPIRVOpLowering<OpTy> {...};
class FooFnLowering : public FooDialect, SPIRVFnLowering {...};

PiperOrigin-RevId: 263597101
This commit is contained in:
Mahesh Ravishankar 2019-08-15 10:54:22 -07:00 committed by A. Unique TensorFlower
parent d71915420b
commit cc980aa416
3 changed files with 66 additions and 59 deletions

View File

@ -33,12 +33,12 @@ class SPIRVDialect;
}
/// Type conversion from Standard Types to SPIR-V Types.
class SPIRVTypeConverter : public TypeConverter {
class SPIRVBasicTypeConverter : public TypeConverter {
public:
explicit SPIRVTypeConverter(MLIRContext *context);
explicit SPIRVBasicTypeConverter(MLIRContext *context);
/// Converts types to SPIR-V supported types.
Type convertType(Type t) override;
virtual Type convertType(Type t);
protected:
spirv::SPIRVDialect *spirvDialect;
@ -47,51 +47,54 @@ protected:
/// Converts a function type according to the requirements of a SPIR-V entry
/// function. The arguments need to be converted to spv.Variables of spv.ptr
/// types so that they could be bound by the runtime.
class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
class SPIRVTypeConverter final : public TypeConverter {
public:
using SPIRVTypeConverter::SPIRVTypeConverter;
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
: basicTypeConverter(basicTypeConverter) {}
/// Convert types to SPIR-V types using the basic type converter.
Type convertType(Type t) override {
return basicTypeConverter->convertType(t);
}
/// Method to convert argument of a function. The `type` is converted to
/// spv.ptr<type, Uniform>.
// TODO(ravishankarm) : Support other storage classes.
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) override;
/// Get the basic type converter.
SPIRVBasicTypeConverter *getBasicTypeConverter() const {
return basicTypeConverter;
}
private:
SPIRVBasicTypeConverter *basicTypeConverter;
};
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
public:
SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
SPIRVEntryFnTypeConverter &entryFnConverter)
SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
: ConversionPattern(OpTy::getOperationName(), 1, context),
typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
typeConverter(typeConverter) {}
protected:
// Type lowering class.
SPIRVTypeConverter &typeConverter;
// Entry function signature converter.
SPIRVEntryFnTypeConverter &entryFnConverter;
};
/// Base Class for legalize a FuncOp within a spv.module. This class can be
/// extended to implement a ConversionPattern to lower a FuncOp. It provides
/// hooks to legalize a FuncOp as a simple function, or as an entry function.
class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
public:
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
/// Method to legalize a function as a non-entry function.
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp);
protected:
/// Method to legalize the function as a non-entry function.
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) const;
/// Method to legalize the function as an entry function.
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) const;
};
/// Method to legalize a function as an entry function.
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp);
/// Appends to a pattern list additional patterns for translating StandardOps to
/// SPIR-V ops.

View File

@ -31,9 +31,9 @@ namespace {
/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
/// attribute gpu.kernel) within a spv.module.
class KernelFnConversion final : public SPIRVFnLowering {
class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
public:
using SPIRVFnLowering::SPIRVFnLowering;
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
@ -47,12 +47,14 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
auto funcOp = cast<FuncOp>(op);
FuncOp newFuncOp;
if (!gpu::GPUDialect::isKernel(funcOp)) {
return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
newFuncOp))
? matchSuccess()
: matchFailure();
}
if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
newFuncOp))) {
return matchFailure();
}
newFuncOp.getOperation()->removeAttr(Identifier::get(
@ -101,16 +103,17 @@ void GPUToSPIRVPass::runOnModule() {
}
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
SPIRVTypeConverter typeConverter(context);
SPIRVEntryFnTypeConverter entryFnConverter(context);
SPIRVBasicTypeConverter basicTypeConverter(context);
SPIRVTypeConverter typeConverter(&basicTypeConverter);
OwningRewritePatternList patterns;
patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
patterns.insert<KernelFnConversion>(context, typeConverter);
populateStandardToSPIRVPatterns(context, patterns);
ConversionTarget target(*context);
target.addLegalDialect<spirv::SPIRVDialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
return basicTypeConverter.isSignatureLegal(Op.getType());
});
if (failed(applyFullConversion(spirvModules, target, patterns,
&typeConverter))) {

View File

@ -30,10 +30,10 @@ using namespace mlir;
// Type Conversion
//===----------------------------------------------------------------------===//
SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
: spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
Type SPIRVTypeConverter::convertType(Type t) {
Type SPIRVBasicTypeConverter::convertType(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirvDialect->isValidSPIRVType(t)) {
return t;
@ -58,10 +58,10 @@ Type SPIRVTypeConverter::convertType(Type t) {
//===----------------------------------------------------------------------===//
LogicalResult
SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) {
SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) {
// Try to convert the given input type.
auto convertedType = convertType(type);
auto convertedType = basicTypeConverter->convertType(type);
// TODO(ravishankarm) : Vulkan spec requires these to be a
// spirv::StructType. This is not a SPIR-V requirement, so just making this a
// pointer type for now.
@ -81,12 +81,10 @@ SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
return success();
}
template <typename Converter>
static LogicalResult
lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter, Converter &typeConverter,
TypeConverter::SignatureConversion &signatureConverter,
FuncOp &newFuncOp) {
static LogicalResult lowerFunctionImpl(
FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
@ -96,7 +94,7 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
for (auto &argType : enumerate(fnType.getInputs())) {
// Get the type of the argument
if (failed(typeConverter.convertSignatureArg(
if (failed(typeConverter->convertSignatureArg(
argType.index(), argType.value(), signatureConverter))) {
return funcOp.emitError("unable to convert argument type ")
<< argType.value() << " to SPIR-V type";
@ -116,23 +114,25 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
return success();
}
LogicalResult
SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) const {
namespace mlir {
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
return lowerFunctionImpl(funcOp, operands, rewriter,
typeConverter->getBasicTypeConverter(),
signatureConverter, newFuncOp);
}
LogicalResult
SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) const {
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
signatureConverter, newFuncOp))) {
return failure();
}
@ -167,6 +167,7 @@ SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
builder.getSymbolRefAttr(newFuncOp.getName()), interface);
return success();
}
} // namespace mlir
//===----------------------------------------------------------------------===//
// Operation conversion