[mlir] Add support for generating dialect declarations via tablegen.

Summary: This generates the class declarations for dialects using the existing 'Dialect' tablegen classes.

Differential Revision: https://reviews.llvm.org/D76185
This commit is contained in:
River Riddle 2020-03-14 20:33:53 -07:00
parent 27f303924e
commit 429d792f23
48 changed files with 388 additions and 281 deletions

View File

@ -28,10 +28,11 @@ function(whole_archive_link target)
endfunction(whole_archive_link)
# Declare a dialect in the include directory
function(add_mlir_dialect dialect dialect_doc_filename)
function(add_mlir_dialect dialect dialect_namespace dialect_doc_filename)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)

View File

@ -39,7 +39,7 @@ is declared using add_mlir_dialect().
```cmake
add_mlir_dialect(FooOps FooOps)
add_mlir_dialect(FooOps foo FooOps)
```

View File

@ -36,17 +36,6 @@ class OpBuilder;
/// symbol.
bool isTopLevelValue(Value value);
class AffineOpsDialect : public Dialect {
public:
AffineOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "affine"; }
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
/// from a source memref to a destination memref. The source and destination
/// memref need not be of the same dimensionality, but need to have the same
@ -504,6 +493,8 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
void fullyComposeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands);
#include "mlir/Dialect/AffineOps/AffineOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"

View File

@ -17,14 +17,15 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
def Affine_Dialect : Dialect {
def AffineOps_Dialect : Dialect {
let name = "affine";
let cppNamespace = "";
let hasConstantMaterializer = 1;
}
// Base class for Affine dialect ops.
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
Op<AffineOps_Dialect, mnemonic, traits> {
// For every affine op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@ -290,7 +291,7 @@ def AffineIfOp : Affine_Op<"if",
}
class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
Op<Affine_Dialect, mnemonic, traits> {
Op<AffineOps_Dialect, mnemonic, traits> {
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
let results = (outs Index);

View File

@ -1 +1 @@
add_mlir_dialect(AffineOps AffineOps)
add_mlir_dialect(AffineOps affine AffineOps)

View File

@ -1 +1 @@
add_mlir_dialect(FxpMathOps FxpMathOps)
add_mlir_dialect(FxpMathOps fxpmath FxpMathOps)

View File

@ -17,11 +17,7 @@
namespace mlir {
namespace fxpmath {
/// Defines the 'FxpMathOps' dialect.
class FxpMathOpsDialect : public Dialect {
public:
FxpMathOpsDialect(MLIRContext *context);
};
#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"

View File

@ -15,10 +15,10 @@
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
def fxpmath_Dialect : Dialect {
def FxpMathOps_Dialect : Dialect {
let name = "fxpmath";
}
@ -78,7 +78,7 @@ def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
//===----------------------------------------------------------------------===//
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
Op<fxpmath_Dialect, mnemonic, traits>;
Op<FxpMathOps_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.

View File

@ -1 +1 @@
add_mlir_dialect(GPUOps GPUOps)
add_mlir_dialect(GPUOps gpu GPUOps)

View File

@ -26,51 +26,6 @@ class FuncOp;
namespace gpu {
/// The dialect containing GPU kernel launching operations and related
/// facilities.
class GPUDialect : public Dialect {
public:
/// Create the dialect in the given `context`.
explicit GPUDialect(MLIRContext *context);
/// Get dialect namespace.
static StringRef getDialectNamespace() { return "gpu"; }
/// Get the name of the attribute used to annotate the modules that contain
/// kernel modules.
static StringRef getContainerModuleAttrName() {
return "gpu.container_module";
}
/// Get the canonical string name of the dialect.
static StringRef getDialectName();
/// Get the name of the attribute used to annotate external kernel functions.
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
/// Get the name of the attribute used to annotate kernel modules.
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(Operation *op);
/// Returns the number of workgroup (thread, block) dimensions supported in
/// the GPU dialect.
// TODO(zinenko,herhut): consider generalizing this.
static unsigned getNumWorkgroupDimensions() { return 3; }
/// Returns the numeric value used to identify the workgroup memory address
/// space.
static unsigned getWorkgroupAddressSpace() { return 3; }
/// Returns the numeric value used to identify the private memory address
/// space.
static unsigned getPrivateAddressSpace() { return 5; }
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attr) override;
};
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
@ -79,6 +34,8 @@ struct KernelDim3 {
Value z;
};
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.h.inc"

View File

@ -28,6 +28,39 @@ def IntLikeOrLLVMInt : TypeConstraint<
def GPU_Dialect : Dialect {
let name = "gpu";
let extraClassDeclaration = [{
/// Get the name of the attribute used to annotate the modules that contain
/// kernel modules.
static StringRef getContainerModuleAttrName() {
return "gpu.container_module";
}
/// Get the name of the attribute used to annotate external kernel
/// functions.
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
/// Get the name of the attribute used to annotate kernel modules.
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
/// Returns whether the given function is a kernel function, i.e., has the
/// 'gpu.kernel' attribute.
static bool isKernel(Operation *op);
/// Returns the number of workgroup (thread, block) dimensions supported in
/// the GPU dialect.
// TODO(zinenko,herhut): consider generalizing this.
static unsigned getNumWorkgroupDimensions() { return 3; }
/// Returns the numeric value used to identify the workgroup memory address
/// space.
static unsigned getWorkgroupAddressSpace() { return 3; }
/// Returns the numeric value used to identify the private memory address
/// space.
static unsigned getPrivateAddressSpace() { return 5; }
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attr) override;
}];
}
class GPU_Op<string mnemonic, list<OpTrait> traits = []> :

View File

@ -1,12 +1,13 @@
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
add_mlir_dialect(NVVMOps NVVMOps)
add_mlir_dialect(ROCDLOps ROCDLOps)
add_mlir_dialect(NVVMOps nvvm NVVMOps)
add_mlir_dialect(ROCDLOps rocdl ROCDLOps)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)

View File

@ -201,32 +201,7 @@ private:
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
class LLVMDialect : public Dialect {
public:
explicit LLVMDialect(MLIRContext *context);
~LLVMDialect();
static StringRef getDialectNamespace() { return "llvm"; }
llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule();
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
/// Verify a region argument attribute registered to this dialect.
/// Returns failure if the verification failed, success otherwise.
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
unsigned argIdx,
NamedAttribute argAttr) override;
private:
friend LLVMType;
std::unique_ptr<detail::LLVMDialectImpl> impl;
};
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.h.inc"
/// Create an LLVM global containing the string "value" at the module containing
/// surrounding the insertion point of builder. Obtain the address of that

View File

@ -19,11 +19,28 @@ include "mlir/IR/OpBase.td"
def LLVM_Dialect : Dialect {
let name = "llvm";
let cppNamespace = "LLVM";
let extraClassDeclaration = [{
~LLVMDialect();
llvm::LLVMContext &getLLVMContext();
llvm::Module &getLLVMModule();
/// Verify a region argument attribute registered to this dialect.
/// Returns failure if the verification failed, success otherwise.
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
unsigned argIdx,
NamedAttribute argAttr) override;
private:
friend LLVMType;
std::unique_ptr<detail::LLVMDialectImpl> impl;
}];
}
// LLVM IR type wrapped in MLIR.
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
"LLVM dialect type">;
def LLVM_Type : DialectType<LLVM_Dialect,
CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
"LLVM dialect type">;
// Type constraint accepting only wrapped LLVM integer types.
def LLVMInt : TypeConstraint<

View File

@ -25,12 +25,7 @@ namespace NVVM {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc"
class NVVMDialect : public Dialect {
public:
explicit NVVMDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "nvvm"; }
};
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc"
} // namespace NVVM
} // namespace mlir

View File

@ -33,12 +33,7 @@ namespace ROCDL {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc"
class ROCDLDialect : public Dialect {
public:
explicit ROCDLDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "rocdl"; }
};
#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc"
} // namespace ROCDL
} // namespace mlir

View File

@ -1,4 +1,4 @@
add_mlir_dialect(LinalgOps LinalgDoc)
add_mlir_dialect(LinalgOps linalg LinalgDoc)
set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)

View File

@ -34,6 +34,6 @@ def Linalg_Dialect : Dialect {
// Whether a type is a RangeType.
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
def Range : Type<LinalgIsRangeTypePred, "range">;
def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
#endif // LINALG_BASE

View File

@ -21,17 +21,7 @@ enum LinalgTypes {
LAST_USED_LINALG_TYPE = Range,
};
class LinalgDialect : public Dialect {
public:
explicit LinalgDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "linalg"; }
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
};
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
/// A RangeType represents a minimal range abstraction (min, max, step).
/// It is constructed by calling the linalg.range op with three values index of

View File

@ -1 +1 @@
add_mlir_dialect(LoopOps LoopOps)
add_mlir_dialect(LoopOps loop LoopOps)

View File

@ -25,11 +25,7 @@ namespace loop {
class TerminatorOp;
class LoopOpsDialect : public Dialect {
public:
LoopOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "loop"; }
};
#include "mlir/Dialect/LoopOps/LoopOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/LoopOps/LoopOps.h.inc"

View File

@ -16,14 +16,14 @@
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
def Loop_Dialect : Dialect {
def LoopOps_Dialect : Dialect {
let name = "loop";
let cppNamespace = "";
}
// Base class for Loop dialect ops.
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Loop_Dialect, mnemonic, traits> {
Op<LoopOps_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)

View File

@ -1 +1 @@
add_mlir_dialect(OpenMPOps OpenMPOps)
add_mlir_dialect(OpenMPOps omp OpenMPOps)

View File

@ -22,13 +22,7 @@ namespace omp {
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
class OpenMPDialect : public Dialect {
public:
explicit OpenMPDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "omp"; }
};
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
} // namespace omp
} // namespace mlir

View File

@ -1 +1 @@
add_mlir_dialect(QuantOps QuantOps)
add_mlir_dialect(QuantOps quant QuantOps)

View File

@ -21,17 +21,7 @@
namespace mlir {
namespace quant {
/// Defines the 'Quantization' dialect
class QuantizationDialect : public Dialect {
public:
QuantizationDialect(MLIRContext *context);
/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
};
#include "mlir/Dialect/QuantOps/QuantOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.h.inc"

View File

@ -13,20 +13,15 @@
#ifndef DIALECT_QUANTOPS_QUANT_OPS_
#define DIALECT_QUANTOPS_QUANT_OPS_
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
def quant_Dialect : Dialect {
let name = "quant";
}
//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//
class quant_Op<string mnemonic, list<OpTrait> traits> :
Op<quant_Dialect, mnemonic, traits>;
Op<Quantization_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Quantization casts

View File

@ -1,4 +1,4 @@
//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -10,8 +10,14 @@
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_QUANTOPS_QUANT_PREDICATES_
#define DIALECT_QUANTOPS_QUANT_PREDICATES_
#ifndef DIALECT_QUANTOPS_QUANT_OPS_BASE_
#define DIALECT_QUANTOPS_QUANT_OPS_BASE_
include "mlir/IR/OpBase.td"
def Quantization_Dialect : Dialect {
let name = "quant";
}
//===----------------------------------------------------------------------===//
// Quantization type definitions
@ -54,10 +60,12 @@ def quant_RealOrStorageValueType :
// An implementation of UniformQuantizedType.
def quant_UniformQuantizedType :
Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
DialectType<Quantization_Dialect,
CPred<"$_self.isa<UniformQuantizedType>()">,
"UniformQuantizedType">;
// Predicate for detecting a container or primitive of UniformQuantizedType.
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_
#endif // DIALECT_QUANTOPS_QUANT_OPS_BASE_

View File

@ -1,4 +1,4 @@
add_mlir_dialect(SPIRVOps SPIRVOps)
add_mlir_dialect(SPIRVOps spv SPIRVOps)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)

View File

@ -22,7 +22,7 @@ include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
// SPIR-V dialect definitions
//===----------------------------------------------------------------------===//
def SPV_Dialect : Dialect {
def SPIRV_Dialect : Dialect {
let name = "spv";
let summary = "The SPIR-V dialect in MLIR.";
@ -46,6 +46,43 @@ def SPV_Dialect : Dialect {
}];
let cppNamespace = "spirv";
let hasConstantMaterializer = 1;
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// Type
//===------------------------------------------------------------------===//
/// Checks if the given `type` is valid in SPIR-V dialect.
static bool isValidType(Type type);
/// Checks if the given `scalar type` is valid in SPIR-V dialect.
static bool isValidScalarType(Type type);
//===------------------------------------------------------------------===//
// Attribute
//===------------------------------------------------------------------===//
/// Returns the attribute name to use when specifying decorations on results
/// of operations.
static std::string getAttributeName(Decoration decoration);
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op.
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attribute) override;
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op's region argument.
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) override;
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op's region result.
LogicalResult verifyRegionResultAttribute(
Operation *op, unsigned regionIndex, unsigned resultIndex,
NamedAttribute attribute) override;
}];
}
//===----------------------------------------------------------------------===//
@ -2953,7 +2990,8 @@ def SPV_SamplerUseAttr:
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
def SPV_VerCapExtAttr : Attr<
def SPV_VerCapExtAttr : DialectAttr<
SPIRV_Dialect,
CPred<"$_self.isa<::mlir::spirv::VerCapExtAttr>()">,
"version-capability-extension attribute"> {
let storageType = "::mlir::spirv::VerCapExtAttr";
@ -2993,10 +3031,14 @@ def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
[SPV_Bool, SPV_Integer, SPV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
"any SPIR-V pointer type">;
def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
"any SPIR-V array type">;
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
"any SPIR-V runtime array type">;
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
"any SPIR-V struct type">;
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
@ -3264,7 +3306,7 @@ def SPV_OpcodeAttr :
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
Op<SPV_Dialect, mnemonic, !listconcat(traits, [
Op<SPIRV_Dialect, mnemonic, !listconcat(traits, [
// TODO(antiagainst): We don't need all of the following traits for
// every op; only the suitabble ones should be added automatically
// after ODS supports dialect-specific contents.

View File

@ -20,67 +20,7 @@ namespace spirv {
enum class Decoration : uint32_t;
class SPIRVDialect : public Dialect {
public:
explicit SPIRVDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "spv"; }
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
/// Checks if the given `type` is valid in SPIR-V dialect.
static bool isValidType(Type type);
/// Checks if the given `scalar type` is valid in SPIR-V dialect.
static bool isValidScalarType(Type type);
/// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
/// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
//===--------------------------------------------------------------------===//
// Attribute
//===--------------------------------------------------------------------===//
/// Returns the attribute name to use when specifying decorations on results
/// of operations.
static std::string getAttributeName(Decoration decoration);
/// Parses an attribute registered to this dialect.
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
/// Prints an attribute registered to this dialect.
void printAttribute(Attribute, DialectAsmPrinter &printer) const override;
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op.
LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attribute) override;
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op's region argument.
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) override;
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
/// given op's region result.
LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
unsigned resultIndex,
NamedAttribute attribute) override;
//===--------------------------------------------------------------------===//
// Constant
//===--------------------------------------------------------------------===//
/// Provides a hook for materializing a constant to this dialect.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
#include "mlir/Dialect/SPIRV/SPIRVOpsDialect.h.inc"
} // end namespace spirv
} // end namespace mlir

View File

@ -29,7 +29,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
// 1) Descriptor Set.
// 2) Binding number.
// 3) Storage class.
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
StructFieldAttr<"descriptor_set", I32Attr>,
StructFieldAttr<"binding", I32Attr>,
StructFieldAttr<"storage_class", SPV_StorageClassAttr>
@ -38,7 +38,7 @@ def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
// For entry functions, this attribute specifies information related to entry
// points in the generated SPIR-V module:
// 1) WorkGroup Size.
def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPIRV_Dialect, [
StructFieldAttr<"local_size", I32ElementsAttr>
]>;
@ -54,7 +54,7 @@ def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
// See https://renderdoc.org/vkspec_chunked/chap36.html#limits for the complete
// list of limits and their explanation for the Vulkan API. The following ones
// are those affecting SPIR-V CodeGen.
def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPV_Dialect, [
def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [
StructFieldAttr<"max_compute_workgroup_invocations", I32Attr>,
StructFieldAttr<"max_compute_workgroup_size", I32ElementsAttr>
]>;

View File

@ -1 +1 @@
add_mlir_dialect(ShapeOps ShapeOps)
add_mlir_dialect(ShapeOps shape ShapeOps)

View File

@ -21,13 +21,6 @@
namespace mlir {
namespace shape {
/// This dialect contains shape inference related operations and facilities.
class ShapeDialect : public Dialect {
public:
/// Create the dialect in the given `context`.
explicit ShapeDialect(MLIRContext *context);
};
namespace ShapeTypes {
enum Kind {
Component = Type::FIRST_SHAPE_TYPE,
@ -112,6 +105,8 @@ public:
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.h.inc"
} // namespace shape
} // namespace mlir

View File

@ -1,6 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRStandardOpsIncGen)

View File

@ -31,20 +31,11 @@ class Builder;
class FuncOp;
class OpBuilder;
class StandardOpsDialect : public Dialect {
public:
StandardOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "std"; }
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///

View File

@ -18,14 +18,15 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
def Std_Dialect : Dialect {
def StandardOps_Dialect : Dialect {
let name = "std";
let cppNamespace = "";
let hasConstantMaterializer = 1;
}
// Base class for Standard dialect ops.
class Std_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Std_Dialect, mnemonic, traits> {
Op<StandardOps_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@ -63,7 +64,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for unary ops. Requires single operand and result. Individual
// classes will have `operand` accessor.
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
Op<StandardOps_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
let results = (outs AnyType);
let printer = [{
return printStandardUnaryOp(this->getOperation(), p);
@ -86,7 +87,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<Std_Dialect, mnemonic,
Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
let results = (outs AnyType);

View File

@ -1,4 +1,4 @@
add_mlir_dialect(VectorOps VectorOps)
add_mlir_dialect(VectorOps vector VectorOps)
set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)

View File

@ -24,18 +24,6 @@ class MLIRContext;
class OwningRewritePatternList;
namespace vector {
/// Dialect for Ops on higher-dimensional vector types.
class VectorOpsDialect : public Dialect {
public:
VectorOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "vector"; }
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
};
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
@ -75,6 +63,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"
#include "mlir/Dialect/VectorOps/VectorOpsDialect.h.inc"
} // end namespace vector
} // end namespace mlir

View File

@ -16,14 +16,15 @@
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
def Vector_Dialect : Dialect {
def VectorOps_Dialect : Dialect {
let name = "vector";
let cppNamespace = "vector";
let hasConstantMaterializer = 1;
}
// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Vector_Dialect, mnemonic, traits> {
Op<VectorOps_Dialect, mnemonic, traits> {
// For every vector op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@ -432,7 +433,7 @@ def Vector_ExtractSlicesOp :
}
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [NoSideEffect,
Op<VectorOps_Dialect, "fma", [NoSideEffect,
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
Results<(outs AnyVector:$result)> {

View File

@ -253,6 +253,13 @@ class Dialect {
// the generated files are included into the dialect, you may want to specify
// a full namespace path or a partial one.
string cppNamespace = name;
// An optional code block containing extra declarations to place in the
// dialect declaration.
code extraClassDeclaration = "";
// If this dialect overrides the hook for materializing constants.
bit hasConstantMaterializer = 0;
}
//===----------------------------------------------------------------------===//
@ -753,6 +760,12 @@ class Attr<Pred condition, string descr = ""> :
Attr baseAttr = ?;
}
// An attribute of a specific dialect.
class DialectAttr<Dialect d, Pred condition, string descr = ""> :
Attr<condition, descr> {
Dialect dialect = d;
}
//===----------------------------------------------------------------------===//
// Attribute modifier definition

View File

@ -25,6 +25,7 @@ class Record;
namespace mlir {
namespace tblgen {
class Dialect;
class Type;
// Wrapper class with helper methods for accessing attribute constraints defined
@ -105,6 +106,9 @@ public:
// Returns the code body for derived attribute. Aborts if this is not a
// derived attribute.
StringRef getDerivedCodeBody() const;
// Returns the dialect for the attribute if defined.
Dialect getDialect() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute

View File

@ -32,6 +32,9 @@ public:
// Returns the C++ namespaces that ops of this dialect should be placed into.
StringRef getCppNamespace() const;
// Returns this dialect's C++ class name.
std::string getCppClassName() const;
// Returns the summary description of the dialect. Returns empty string if
// none.
StringRef getSummary() const;
@ -39,6 +42,12 @@ public:
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
// Returns if this dialect has a constant materializer or not.
bool hasConstantMaterializer() const;
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;

View File

@ -28,15 +28,13 @@ using namespace mlir::gpu;
// GPUDialect
//===----------------------------------------------------------------------===//
StringRef GPUDialect::getDialectName() { return "gpu"; }
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
GPUDialect::GPUDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"

View File

@ -132,6 +132,10 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
return def->getValueAsString("body");
}
tblgen::Dialect tblgen::Attribute::getDialect() const {
return Dialect(def->getValueAsDef("dialect"));
}
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");

View File

@ -24,6 +24,13 @@ StringRef tblgen::Dialect::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
std::string tblgen::Dialect::getCppClassName() const {
// Simply use the name and remove any '_' tokens.
std::string cppName = def->getName().str();
llvm::erase_if(cppName, [](char c) { return c == '_'; });
return cppName;
}
static StringRef getAsStringOrEmpty(const llvm::Record &record,
StringRef fieldName) {
if (auto valueInit = record.getValueInit(fieldName)) {
@ -42,6 +49,15 @@ StringRef tblgen::Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
llvm::Optional<StringRef> tblgen::Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
bool tblgen::Dialect::hasConstantMaterializer() const {
return def->getValueAsBit("hasConstantMaterializer");
}
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}

View File

@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
)
add_tablegen(mlir-tblgen MLIR
DialectGen.cpp
EnumsGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp

View File

@ -0,0 +1,166 @@
//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DialectGen uses the description of dialects to generate C++ definitions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
using namespace mlir;
using namespace mlir::tblgen;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
static llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
/// Given a set of records for a T, filter the ones that correspond to
/// the given dialect.
template <typename T>
static auto filterForDialect(ArrayRef<llvm::Record *> records,
Dialect &dialect) {
return llvm::make_filter_range(records, [&](const llvm::Record *record) {
return T(record).getDialect() == dialect;
});
}
//===----------------------------------------------------------------------===//
// GEN: Dialect declarations
//===----------------------------------------------------------------------===//
/// The code block for the start of a dialect class declaration.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
public:
explicit {0}(::mlir::MLIRContext *context);
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
)";
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const override;
/// Print an attribute registered to this dialect.
void printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &os) const override;
)";
/// The code block for the type parser/printer hooks.
static const char *const typeParserDecl = R"(
/// Parse a type registered to this dialect.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
/// Print a type registered to this dialect.
void printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const override;
)";
/// The code block for the constant materializer hook.
static const char *const constantMaterializerDecl = R"(
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
::mlir::Attribute value,
::mlir::Type type,
::mlir::Location loc) override;
)";
/// Generate the declaration for the given dialect class.
static void emitDialectDecl(
Dialect &dialect,
FunctionTraits<decltype(&filterForDialect<Attribute>)>::result_t
dialectAttrs,
FunctionTraits<decltype(&filterForDialect<Type>)>::result_t dialectTypes,
raw_ostream &os) {
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.
if (!dialectAttrs.empty())
os << attrParserDecl;
if (!dialectTypes.empty())
os << typeParserDecl;
// Add the decls for the various features of the dialect.
if (dialect.hasConstantMaterializer())
os << constantMaterializerDecl;
if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
os << *extraDecl;
// End the dialect decl.
os << "};\n";
}
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("Dialect Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("Dialect");
if (defs.empty())
return false;
// Select the dialect to gen for.
const llvm::Record *dialectDef = nullptr;
if (defs.size() == 1 && selectedDialect.getNumOccurrences() == 0) {
dialectDef = defs.front();
} else if (selectedDialect.getNumOccurrences() == 0) {
llvm::errs() << "when more than 1 dialect is present, one must be selected "
"via '-dialect'";
return true;
} else {
auto dialectIt = llvm::find_if(defs, [](const llvm::Record *def) {
return Dialect(def).getName() == selectedDialect;
});
if (dialectIt == defs.end()) {
llvm::errs() << "selected dialect with '-dialect' does not exist";
return true;
}
dialectDef = *dialectIt;
}
auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
Dialect dialect(dialectDef);
emitDialectDecl(dialect, filterForDialect<Attribute>(attrDefs, dialect),
filterForDialect<Type>(typeDefs, dialect), os);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: Dialect registration hooks
//===----------------------------------------------------------------------===//
static mlir::GenRegistration
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return emitDialectDecls(records, os);
});