forked from OSchip/llvm-project
[mlir] Add support for generating dialect declarations via tablegen.
Summary: This generates the class declarations for dialects using the existing 'Dialect' tablegen classes. Differential Revision: https://reviews.llvm.org/D76185
This commit is contained in:
parent
27f303924e
commit
429d792f23
|
@ -28,10 +28,11 @@ function(whole_archive_link target)
|
|||
endfunction(whole_archive_link)
|
||||
|
||||
# Declare a dialect in the include directory
|
||||
function(add_mlir_dialect dialect dialect_doc_filename)
|
||||
function(add_mlir_dialect dialect dialect_namespace dialect_doc_filename)
|
||||
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
|
||||
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
|
||||
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
|
||||
add_public_tablegen_target(MLIR${dialect}IncGen)
|
||||
add_dependencies(mlir-headers MLIR${dialect}IncGen)
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ is declared using add_mlir_dialect().
|
|||
|
||||
```cmake
|
||||
|
||||
add_mlir_dialect(FooOps FooOps)
|
||||
add_mlir_dialect(FooOps foo FooOps)
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -36,17 +36,6 @@ class OpBuilder;
|
|||
/// symbol.
|
||||
bool isTopLevelValue(Value value);
|
||||
|
||||
class AffineOpsDialect : public Dialect {
|
||||
public:
|
||||
AffineOpsDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "affine"; }
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
};
|
||||
|
||||
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
|
||||
/// from a source memref to a destination memref. The source and destination
|
||||
/// memref need not be of the same dimensionality, but need to have the same
|
||||
|
@ -504,6 +493,8 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
|
|||
void fullyComposeAffineMapAndOperands(AffineMap *map,
|
||||
SmallVectorImpl<Value> *operands);
|
||||
|
||||
#include "mlir/Dialect/AffineOps/AffineOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
|
||||
|
||||
|
|
|
@ -17,14 +17,15 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td"
|
|||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def Affine_Dialect : Dialect {
|
||||
def AffineOps_Dialect : Dialect {
|
||||
let name = "affine";
|
||||
let cppNamespace = "";
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
// Base class for Affine dialect ops.
|
||||
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Affine_Dialect, mnemonic, traits> {
|
||||
Op<AffineOps_Dialect, mnemonic, traits> {
|
||||
// For every affine op, there needs to be a:
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
|
@ -290,7 +291,7 @@ def AffineIfOp : Affine_Op<"if",
|
|||
}
|
||||
|
||||
class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Affine_Dialect, mnemonic, traits> {
|
||||
Op<AffineOps_Dialect, mnemonic, traits> {
|
||||
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
|
||||
let results = (outs Index);
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(AffineOps AffineOps)
|
||||
add_mlir_dialect(AffineOps affine AffineOps)
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(FxpMathOps FxpMathOps)
|
||||
add_mlir_dialect(FxpMathOps fxpmath FxpMathOps)
|
||||
|
|
|
@ -17,11 +17,7 @@
|
|||
namespace mlir {
|
||||
namespace fxpmath {
|
||||
|
||||
/// Defines the 'FxpMathOps' dialect.
|
||||
class FxpMathOpsDialect : public Dialect {
|
||||
public:
|
||||
FxpMathOpsDialect(MLIRContext *context);
|
||||
};
|
||||
#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/QuantOps/QuantPredicates.td"
|
||||
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def fxpmath_Dialect : Dialect {
|
||||
def FxpMathOps_Dialect : Dialect {
|
||||
let name = "fxpmath";
|
||||
}
|
||||
|
||||
|
@ -78,7 +78,7 @@ def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<fxpmath_Dialect, mnemonic, traits>;
|
||||
Op<FxpMathOps_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Fixed-point (fxp) arithmetic ops used by kernels.
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(GPUOps GPUOps)
|
||||
add_mlir_dialect(GPUOps gpu GPUOps)
|
||||
|
|
|
@ -26,51 +26,6 @@ class FuncOp;
|
|||
|
||||
namespace gpu {
|
||||
|
||||
/// The dialect containing GPU kernel launching operations and related
|
||||
/// facilities.
|
||||
class GPUDialect : public Dialect {
|
||||
public:
|
||||
/// Create the dialect in the given `context`.
|
||||
explicit GPUDialect(MLIRContext *context);
|
||||
/// Get dialect namespace.
|
||||
static StringRef getDialectNamespace() { return "gpu"; }
|
||||
|
||||
/// Get the name of the attribute used to annotate the modules that contain
|
||||
/// kernel modules.
|
||||
static StringRef getContainerModuleAttrName() {
|
||||
return "gpu.container_module";
|
||||
}
|
||||
|
||||
/// Get the canonical string name of the dialect.
|
||||
static StringRef getDialectName();
|
||||
|
||||
/// Get the name of the attribute used to annotate external kernel functions.
|
||||
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
|
||||
|
||||
/// Get the name of the attribute used to annotate kernel modules.
|
||||
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
|
||||
|
||||
/// Returns whether the given function is a kernel function, i.e., has the
|
||||
/// 'gpu.kernel' attribute.
|
||||
static bool isKernel(Operation *op);
|
||||
|
||||
/// Returns the number of workgroup (thread, block) dimensions supported in
|
||||
/// the GPU dialect.
|
||||
// TODO(zinenko,herhut): consider generalizing this.
|
||||
static unsigned getNumWorkgroupDimensions() { return 3; }
|
||||
|
||||
/// Returns the numeric value used to identify the workgroup memory address
|
||||
/// space.
|
||||
static unsigned getWorkgroupAddressSpace() { return 3; }
|
||||
|
||||
/// Returns the numeric value used to identify the private memory address
|
||||
/// space.
|
||||
static unsigned getPrivateAddressSpace() { return 5; }
|
||||
|
||||
LogicalResult verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) override;
|
||||
};
|
||||
|
||||
/// Utility class for the GPU dialect to represent triples of `Value`s
|
||||
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
|
||||
struct KernelDim3 {
|
||||
|
@ -79,6 +34,8 @@ struct KernelDim3 {
|
|||
Value z;
|
||||
};
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/GPU/GPUOps.h.inc"
|
||||
|
||||
|
|
|
@ -28,6 +28,39 @@ def IntLikeOrLLVMInt : TypeConstraint<
|
|||
|
||||
def GPU_Dialect : Dialect {
|
||||
let name = "gpu";
|
||||
let extraClassDeclaration = [{
|
||||
/// Get the name of the attribute used to annotate the modules that contain
|
||||
/// kernel modules.
|
||||
static StringRef getContainerModuleAttrName() {
|
||||
return "gpu.container_module";
|
||||
}
|
||||
/// Get the name of the attribute used to annotate external kernel
|
||||
/// functions.
|
||||
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
|
||||
|
||||
/// Get the name of the attribute used to annotate kernel modules.
|
||||
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
|
||||
|
||||
/// Returns whether the given function is a kernel function, i.e., has the
|
||||
/// 'gpu.kernel' attribute.
|
||||
static bool isKernel(Operation *op);
|
||||
|
||||
/// Returns the number of workgroup (thread, block) dimensions supported in
|
||||
/// the GPU dialect.
|
||||
// TODO(zinenko,herhut): consider generalizing this.
|
||||
static unsigned getNumWorkgroupDimensions() { return 3; }
|
||||
|
||||
/// Returns the numeric value used to identify the workgroup memory address
|
||||
/// space.
|
||||
static unsigned getWorkgroupAddressSpace() { return 3; }
|
||||
|
||||
/// Returns the numeric value used to identify the private memory address
|
||||
/// space.
|
||||
static unsigned getPrivateAddressSpace() { return 5; }
|
||||
|
||||
LogicalResult verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) override;
|
||||
}];
|
||||
}
|
||||
|
||||
class GPU_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
|
||||
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRLLVMOpsIncGen)
|
||||
|
||||
add_mlir_dialect(NVVMOps NVVMOps)
|
||||
add_mlir_dialect(ROCDLOps ROCDLOps)
|
||||
add_mlir_dialect(NVVMOps nvvm NVVMOps)
|
||||
add_mlir_dialect(ROCDLOps rocdl ROCDLOps)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
|
||||
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
|
||||
|
|
|
@ -201,32 +201,7 @@ private:
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
|
||||
|
||||
class LLVMDialect : public Dialect {
|
||||
public:
|
||||
explicit LLVMDialect(MLIRContext *context);
|
||||
~LLVMDialect();
|
||||
static StringRef getDialectNamespace() { return "llvm"; }
|
||||
|
||||
llvm::LLVMContext &getLLVMContext();
|
||||
llvm::Module &getLLVMModule();
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(DialectAsmParser &parser) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(Type type, DialectAsmPrinter &os) const override;
|
||||
|
||||
/// Verify a region argument attribute registered to this dialect.
|
||||
/// Returns failure if the verification failed, success otherwise.
|
||||
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
|
||||
unsigned argIdx,
|
||||
NamedAttribute argAttr) override;
|
||||
|
||||
private:
|
||||
friend LLVMType;
|
||||
|
||||
std::unique_ptr<detail::LLVMDialectImpl> impl;
|
||||
};
|
||||
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.h.inc"
|
||||
|
||||
/// Create an LLVM global containing the string "value" at the module containing
|
||||
/// surrounding the insertion point of builder. Obtain the address of that
|
||||
|
|
|
@ -19,11 +19,28 @@ include "mlir/IR/OpBase.td"
|
|||
def LLVM_Dialect : Dialect {
|
||||
let name = "llvm";
|
||||
let cppNamespace = "LLVM";
|
||||
let extraClassDeclaration = [{
|
||||
~LLVMDialect();
|
||||
llvm::LLVMContext &getLLVMContext();
|
||||
llvm::Module &getLLVMModule();
|
||||
|
||||
/// Verify a region argument attribute registered to this dialect.
|
||||
/// Returns failure if the verification failed, success otherwise.
|
||||
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
|
||||
unsigned argIdx,
|
||||
NamedAttribute argAttr) override;
|
||||
|
||||
private:
|
||||
friend LLVMType;
|
||||
|
||||
std::unique_ptr<detail::LLVMDialectImpl> impl;
|
||||
}];
|
||||
}
|
||||
|
||||
// LLVM IR type wrapped in MLIR.
|
||||
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
|
||||
"LLVM dialect type">;
|
||||
def LLVM_Type : DialectType<LLVM_Dialect,
|
||||
CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
|
||||
"LLVM dialect type">;
|
||||
|
||||
// Type constraint accepting only wrapped LLVM integer types.
|
||||
def LLVMInt : TypeConstraint<
|
||||
|
|
|
@ -25,12 +25,7 @@ namespace NVVM {
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc"
|
||||
|
||||
class NVVMDialect : public Dialect {
|
||||
public:
|
||||
explicit NVVMDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "nvvm"; }
|
||||
};
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc"
|
||||
|
||||
} // namespace NVVM
|
||||
} // namespace mlir
|
||||
|
|
|
@ -33,12 +33,7 @@ namespace ROCDL {
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc"
|
||||
|
||||
class ROCDLDialect : public Dialect {
|
||||
public:
|
||||
explicit ROCDLDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "rocdl"; }
|
||||
};
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc"
|
||||
|
||||
} // namespace ROCDL
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
add_mlir_dialect(LinalgOps LinalgDoc)
|
||||
add_mlir_dialect(LinalgOps linalg LinalgDoc)
|
||||
set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
|
||||
mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)
|
||||
|
|
|
@ -34,6 +34,6 @@ def Linalg_Dialect : Dialect {
|
|||
|
||||
// Whether a type is a RangeType.
|
||||
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
|
||||
def Range : Type<LinalgIsRangeTypePred, "range">;
|
||||
def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
|
||||
|
||||
#endif // LINALG_BASE
|
||||
|
|
|
@ -21,17 +21,7 @@ enum LinalgTypes {
|
|||
LAST_USED_LINALG_TYPE = Range,
|
||||
};
|
||||
|
||||
class LinalgDialect : public Dialect {
|
||||
public:
|
||||
explicit LinalgDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "linalg"; }
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(DialectAsmParser &parser) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(Type type, DialectAsmPrinter &os) const override;
|
||||
};
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
|
||||
|
||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
||||
/// It is constructed by calling the linalg.range op with three values index of
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(LoopOps LoopOps)
|
||||
add_mlir_dialect(LoopOps loop LoopOps)
|
||||
|
|
|
@ -25,11 +25,7 @@ namespace loop {
|
|||
|
||||
class TerminatorOp;
|
||||
|
||||
class LoopOpsDialect : public Dialect {
|
||||
public:
|
||||
LoopOpsDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "loop"; }
|
||||
};
|
||||
#include "mlir/Dialect/LoopOps/LoopOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h.inc"
|
||||
|
|
|
@ -16,14 +16,14 @@
|
|||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def Loop_Dialect : Dialect {
|
||||
def LoopOps_Dialect : Dialect {
|
||||
let name = "loop";
|
||||
let cppNamespace = "";
|
||||
}
|
||||
|
||||
// Base class for Loop dialect ops.
|
||||
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Loop_Dialect, mnemonic, traits> {
|
||||
Op<LoopOps_Dialect, mnemonic, traits> {
|
||||
// For every standard op, there needs to be a:
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(OpenMPOps OpenMPOps)
|
||||
add_mlir_dialect(OpenMPOps omp OpenMPOps)
|
||||
|
|
|
@ -22,13 +22,7 @@ namespace omp {
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
|
||||
|
||||
class OpenMPDialect : public Dialect {
|
||||
public:
|
||||
explicit OpenMPDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "omp"; }
|
||||
};
|
||||
|
||||
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
|
||||
} // namespace omp
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(QuantOps QuantOps)
|
||||
add_mlir_dialect(QuantOps quant QuantOps)
|
||||
|
|
|
@ -21,17 +21,7 @@
|
|||
namespace mlir {
|
||||
namespace quant {
|
||||
|
||||
/// Defines the 'Quantization' dialect
|
||||
class QuantizationDialect : public Dialect {
|
||||
public:
|
||||
QuantizationDialect(MLIRContext *context);
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(DialectAsmParser &parser) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(Type type, DialectAsmPrinter &os) const override;
|
||||
};
|
||||
#include "mlir/Dialect/QuantOps/QuantOpsDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h.inc"
|
||||
|
|
|
@ -13,20 +13,15 @@
|
|||
#ifndef DIALECT_QUANTOPS_QUANT_OPS_
|
||||
#define DIALECT_QUANTOPS_QUANT_OPS_
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/QuantOps/QuantPredicates.td"
|
||||
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def quant_Dialect : Dialect {
|
||||
let name = "quant";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Base classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class quant_Op<string mnemonic, list<OpTrait> traits> :
|
||||
Op<quant_Dialect, mnemonic, traits>;
|
||||
Op<Quantization_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization casts
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
|
||||
//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -10,8 +10,14 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_QUANTOPS_QUANT_PREDICATES_
|
||||
#define DIALECT_QUANTOPS_QUANT_PREDICATES_
|
||||
#ifndef DIALECT_QUANTOPS_QUANT_OPS_BASE_
|
||||
#define DIALECT_QUANTOPS_QUANT_OPS_BASE_
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Quantization_Dialect : Dialect {
|
||||
let name = "quant";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization type definitions
|
||||
|
@ -54,10 +60,12 @@ def quant_RealOrStorageValueType :
|
|||
|
||||
// An implementation of UniformQuantizedType.
|
||||
def quant_UniformQuantizedType :
|
||||
Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
|
||||
DialectType<Quantization_Dialect,
|
||||
CPred<"$_self.isa<UniformQuantizedType>()">,
|
||||
"UniformQuantizedType">;
|
||||
|
||||
// Predicate for detecting a container or primitive of UniformQuantizedType.
|
||||
def quant_UniformQuantizedValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
|
||||
|
||||
#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_
|
||||
#endif // DIALECT_QUANTOPS_QUANT_OPS_BASE_
|
|
@ -1,4 +1,4 @@
|
|||
add_mlir_dialect(SPIRVOps SPIRVOps)
|
||||
add_mlir_dialect(SPIRVOps spv SPIRVOps)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
||||
|
|
|
@ -22,7 +22,7 @@ include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
|
|||
// SPIR-V dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SPV_Dialect : Dialect {
|
||||
def SPIRV_Dialect : Dialect {
|
||||
let name = "spv";
|
||||
|
||||
let summary = "The SPIR-V dialect in MLIR.";
|
||||
|
@ -46,6 +46,43 @@ def SPV_Dialect : Dialect {
|
|||
}];
|
||||
|
||||
let cppNamespace = "spirv";
|
||||
let hasConstantMaterializer = 1;
|
||||
let extraClassDeclaration = [{
|
||||
//===------------------------------------------------------------------===//
|
||||
// Type
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Checks if the given `type` is valid in SPIR-V dialect.
|
||||
static bool isValidType(Type type);
|
||||
|
||||
/// Checks if the given `scalar type` is valid in SPIR-V dialect.
|
||||
static bool isValidScalarType(Type type);
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Attribute
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the attribute name to use when specifying decorations on results
|
||||
/// of operations.
|
||||
static std::string getAttributeName(Decoration decoration);
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op.
|
||||
LogicalResult verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attribute) override;
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op's region argument.
|
||||
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
|
||||
unsigned argIndex,
|
||||
NamedAttribute attribute) override;
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op's region result.
|
||||
LogicalResult verifyRegionResultAttribute(
|
||||
Operation *op, unsigned regionIndex, unsigned resultIndex,
|
||||
NamedAttribute attribute) override;
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2953,7 +2990,8 @@ def SPV_SamplerUseAttr:
|
|||
// SPIR-V attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SPV_VerCapExtAttr : Attr<
|
||||
def SPV_VerCapExtAttr : DialectAttr<
|
||||
SPIRV_Dialect,
|
||||
CPred<"$_self.isa<::mlir::spirv::VerCapExtAttr>()">,
|
||||
"version-capability-extension attribute"> {
|
||||
let storageType = "::mlir::spirv::VerCapExtAttr";
|
||||
|
@ -2993,10 +3031,14 @@ def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
|
|||
[SPV_Bool, SPV_Integer, SPV_Float]>;
|
||||
// Component type check is done in the type parser for the following SPIR-V
|
||||
// dialect-specific types so we use "Any" here.
|
||||
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
|
||||
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
|
||||
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
|
||||
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
|
||||
def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
|
||||
"any SPIR-V pointer type">;
|
||||
def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
|
||||
"any SPIR-V array type">;
|
||||
def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
|
||||
"any SPIR-V runtime array type">;
|
||||
def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
|
||||
"any SPIR-V struct type">;
|
||||
|
||||
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
|
||||
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
|
||||
|
@ -3264,7 +3306,7 @@ def SPV_OpcodeAttr :
|
|||
|
||||
// Base class for all SPIR-V ops.
|
||||
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<SPV_Dialect, mnemonic, !listconcat(traits, [
|
||||
Op<SPIRV_Dialect, mnemonic, !listconcat(traits, [
|
||||
// TODO(antiagainst): We don't need all of the following traits for
|
||||
// every op; only the suitabble ones should be added automatically
|
||||
// after ODS supports dialect-specific contents.
|
||||
|
|
|
@ -20,67 +20,7 @@ namespace spirv {
|
|||
|
||||
enum class Decoration : uint32_t;
|
||||
|
||||
class SPIRVDialect : public Dialect {
|
||||
public:
|
||||
explicit SPIRVDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "spv"; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Type
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Checks if the given `type` is valid in SPIR-V dialect.
|
||||
static bool isValidType(Type type);
|
||||
|
||||
/// Checks if the given `scalar type` is valid in SPIR-V dialect.
|
||||
static bool isValidScalarType(Type type);
|
||||
|
||||
/// Parses a type registered to this dialect.
|
||||
Type parseType(DialectAsmParser &parser) const override;
|
||||
|
||||
/// Prints a type registered to this dialect.
|
||||
void printType(Type type, DialectAsmPrinter &os) const override;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Attribute
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the attribute name to use when specifying decorations on results
|
||||
/// of operations.
|
||||
static std::string getAttributeName(Decoration decoration);
|
||||
|
||||
/// Parses an attribute registered to this dialect.
|
||||
Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
|
||||
|
||||
/// Prints an attribute registered to this dialect.
|
||||
void printAttribute(Attribute, DialectAsmPrinter &printer) const override;
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op.
|
||||
LogicalResult verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attribute) override;
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op's region argument.
|
||||
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
|
||||
unsigned argIndex,
|
||||
NamedAttribute attribute) override;
|
||||
|
||||
/// Provides a hook for verifying SPIR-V dialect attributes attached to the
|
||||
/// given op's region result.
|
||||
LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
|
||||
unsigned resultIndex,
|
||||
NamedAttribute attribute) override;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Constant
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Provides a hook for materializing a constant to this dialect.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
};
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOpsDialect.h.inc"
|
||||
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -29,7 +29,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
|
|||
// 1) Descriptor Set.
|
||||
// 2) Binding number.
|
||||
// 3) Storage class.
|
||||
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
|
||||
def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
|
||||
StructFieldAttr<"descriptor_set", I32Attr>,
|
||||
StructFieldAttr<"binding", I32Attr>,
|
||||
StructFieldAttr<"storage_class", SPV_StorageClassAttr>
|
||||
|
@ -38,7 +38,7 @@ def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
|
|||
// For entry functions, this attribute specifies information related to entry
|
||||
// points in the generated SPIR-V module:
|
||||
// 1) WorkGroup Size.
|
||||
def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
|
||||
def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPIRV_Dialect, [
|
||||
StructFieldAttr<"local_size", I32ElementsAttr>
|
||||
]>;
|
||||
|
||||
|
@ -54,7 +54,7 @@ def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
|
|||
// See https://renderdoc.org/vkspec_chunked/chap36.html#limits for the complete
|
||||
// list of limits and their explanation for the Vulkan API. The following ones
|
||||
// are those affecting SPIR-V CodeGen.
|
||||
def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPV_Dialect, [
|
||||
def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [
|
||||
StructFieldAttr<"max_compute_workgroup_invocations", I32Attr>,
|
||||
StructFieldAttr<"max_compute_workgroup_size", I32ElementsAttr>
|
||||
]>;
|
||||
|
|
|
@ -1 +1 @@
|
|||
add_mlir_dialect(ShapeOps ShapeOps)
|
||||
add_mlir_dialect(ShapeOps shape ShapeOps)
|
||||
|
|
|
@ -21,13 +21,6 @@
|
|||
namespace mlir {
|
||||
namespace shape {
|
||||
|
||||
/// This dialect contains shape inference related operations and facilities.
|
||||
class ShapeDialect : public Dialect {
|
||||
public:
|
||||
/// Create the dialect in the given `context`.
|
||||
explicit ShapeDialect(MLIRContext *context);
|
||||
};
|
||||
|
||||
namespace ShapeTypes {
|
||||
enum Kind {
|
||||
Component = Type::FIRST_SHAPE_TYPE,
|
||||
|
@ -112,6 +105,8 @@ public:
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.h.inc"
|
||||
|
||||
} // namespace shape
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Ops.td)
|
||||
mlir_tablegen(Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(OpsDialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRStandardOpsIncGen)
|
||||
|
|
|
@ -31,20 +31,11 @@ class Builder;
|
|||
class FuncOp;
|
||||
class OpBuilder;
|
||||
|
||||
class StandardOpsDialect : public Dialect {
|
||||
public:
|
||||
StandardOpsDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "std"; }
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
|
||||
|
||||
/// This is a refinement of the "constant" op for the case where it is
|
||||
/// returning a float value of FloatType.
|
||||
///
|
||||
|
|
|
@ -18,14 +18,15 @@ include "mlir/Interfaces/CallInterfaces.td"
|
|||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def Std_Dialect : Dialect {
|
||||
def StandardOps_Dialect : Dialect {
|
||||
let name = "std";
|
||||
let cppNamespace = "";
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
// Base class for Standard dialect ops.
|
||||
class Std_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Std_Dialect, mnemonic, traits> {
|
||||
Op<StandardOps_Dialect, mnemonic, traits> {
|
||||
// For every standard op, there needs to be a:
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
|
@ -63,7 +64,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// Base class for unary ops. Requires single operand and result. Individual
|
||||
// classes will have `operand` accessor.
|
||||
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
|
||||
Op<StandardOps_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
|
||||
let results = (outs AnyType);
|
||||
let printer = [{
|
||||
return printStandardUnaryOp(this->getOperation(), p);
|
||||
|
@ -86,7 +87,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
|
|||
// results to be of the same type, but does not constrain them to specific
|
||||
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
|
||||
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Std_Dialect, mnemonic,
|
||||
Op<StandardOps_Dialect, mnemonic,
|
||||
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
|
||||
|
||||
let results = (outs AnyType);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
add_mlir_dialect(VectorOps VectorOps)
|
||||
add_mlir_dialect(VectorOps vector VectorOps)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
|
||||
mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)
|
||||
|
|
|
@ -24,18 +24,6 @@ class MLIRContext;
|
|||
class OwningRewritePatternList;
|
||||
namespace vector {
|
||||
|
||||
/// Dialect for Ops on higher-dimensional vector types.
|
||||
class VectorOpsDialect : public Dialect {
|
||||
public:
|
||||
VectorOpsDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "vector"; }
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||
Location loc) override;
|
||||
};
|
||||
|
||||
/// Collect a set of vector-to-vector canonicalization patterns.
|
||||
void populateVectorToVectorCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||
|
@ -75,6 +63,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"
|
||||
|
||||
#include "mlir/Dialect/VectorOps/VectorOpsDialect.h.inc"
|
||||
|
||||
} // end namespace vector
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -16,14 +16,15 @@
|
|||
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
def Vector_Dialect : Dialect {
|
||||
def VectorOps_Dialect : Dialect {
|
||||
let name = "vector";
|
||||
let cppNamespace = "vector";
|
||||
let hasConstantMaterializer = 1;
|
||||
}
|
||||
|
||||
// Base class for Vector dialect ops.
|
||||
class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Vector_Dialect, mnemonic, traits> {
|
||||
Op<VectorOps_Dialect, mnemonic, traits> {
|
||||
// For every vector op, there needs to be a:
|
||||
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
|
||||
// * LogicalResult verify(${C++ class of Op} op)
|
||||
|
@ -432,7 +433,7 @@ def Vector_ExtractSlicesOp :
|
|||
}
|
||||
|
||||
def Vector_FMAOp :
|
||||
Op<Vector_Dialect, "fma", [NoSideEffect,
|
||||
Op<VectorOps_Dialect, "fma", [NoSideEffect,
|
||||
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
|
||||
Results<(outs AnyVector:$result)> {
|
||||
|
|
|
@ -253,6 +253,13 @@ class Dialect {
|
|||
// the generated files are included into the dialect, you may want to specify
|
||||
// a full namespace path or a partial one.
|
||||
string cppNamespace = name;
|
||||
|
||||
// An optional code block containing extra declarations to place in the
|
||||
// dialect declaration.
|
||||
code extraClassDeclaration = "";
|
||||
|
||||
// If this dialect overrides the hook for materializing constants.
|
||||
bit hasConstantMaterializer = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -753,6 +760,12 @@ class Attr<Pred condition, string descr = ""> :
|
|||
Attr baseAttr = ?;
|
||||
}
|
||||
|
||||
// An attribute of a specific dialect.
|
||||
class DialectAttr<Dialect d, Pred condition, string descr = ""> :
|
||||
Attr<condition, descr> {
|
||||
Dialect dialect = d;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute modifier definition
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ class Record;
|
|||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
class Dialect;
|
||||
class Type;
|
||||
|
||||
// Wrapper class with helper methods for accessing attribute constraints defined
|
||||
|
@ -105,6 +106,9 @@ public:
|
|||
// Returns the code body for derived attribute. Aborts if this is not a
|
||||
// derived attribute.
|
||||
StringRef getDerivedCodeBody() const;
|
||||
|
||||
// Returns the dialect for the attribute if defined.
|
||||
Dialect getDialect() const;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR constant attribute
|
||||
|
|
|
@ -32,6 +32,9 @@ public:
|
|||
// Returns the C++ namespaces that ops of this dialect should be placed into.
|
||||
StringRef getCppNamespace() const;
|
||||
|
||||
// Returns this dialect's C++ class name.
|
||||
std::string getCppClassName() const;
|
||||
|
||||
// Returns the summary description of the dialect. Returns empty string if
|
||||
// none.
|
||||
StringRef getSummary() const;
|
||||
|
@ -39,6 +42,12 @@ public:
|
|||
// Returns the description of the dialect. Returns empty string if none.
|
||||
StringRef getDescription() const;
|
||||
|
||||
// Returns the dialects extra class declaration code.
|
||||
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||
|
||||
// Returns if this dialect has a constant materializer or not.
|
||||
bool hasConstantMaterializer() const;
|
||||
|
||||
// Returns whether two dialects are equal by checking the equality of the
|
||||
// underlying record.
|
||||
bool operator==(const Dialect &other) const;
|
||||
|
|
|
@ -28,15 +28,13 @@ using namespace mlir::gpu;
|
|||
// GPUDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef GPUDialect::getDialectName() { return "gpu"; }
|
||||
|
||||
bool GPUDialect::isKernel(Operation *op) {
|
||||
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
|
||||
return static_cast<bool>(isKernelAttr);
|
||||
}
|
||||
|
||||
GPUDialect::GPUDialect(MLIRContext *context)
|
||||
: Dialect(getDialectName(), context) {
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
|
||||
|
|
|
@ -132,6 +132,10 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
|||
return def->getValueAsString("body");
|
||||
}
|
||||
|
||||
tblgen::Dialect tblgen::Attribute::getDialect() const {
|
||||
return Dialect(def->getValueAsDef("dialect"));
|
||||
}
|
||||
|
||||
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
|
||||
assert(def->isSubClassOf("ConstantAttr") &&
|
||||
"must be subclass of TableGen 'ConstantAttr' class");
|
||||
|
|
|
@ -24,6 +24,13 @@ StringRef tblgen::Dialect::getCppNamespace() const {
|
|||
return def->getValueAsString("cppNamespace");
|
||||
}
|
||||
|
||||
std::string tblgen::Dialect::getCppClassName() const {
|
||||
// Simply use the name and remove any '_' tokens.
|
||||
std::string cppName = def->getName().str();
|
||||
llvm::erase_if(cppName, [](char c) { return c == '_'; });
|
||||
return cppName;
|
||||
}
|
||||
|
||||
static StringRef getAsStringOrEmpty(const llvm::Record &record,
|
||||
StringRef fieldName) {
|
||||
if (auto valueInit = record.getValueInit(fieldName)) {
|
||||
|
@ -42,6 +49,15 @@ StringRef tblgen::Dialect::getDescription() const {
|
|||
return getAsStringOrEmpty(*def, "description");
|
||||
}
|
||||
|
||||
llvm::Optional<StringRef> tblgen::Dialect::getExtraClassDeclaration() const {
|
||||
auto value = def->getValueAsString("extraClassDeclaration");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
bool tblgen::Dialect::hasConstantMaterializer() const {
|
||||
return def->getValueAsBit("hasConstantMaterializer");
|
||||
}
|
||||
|
||||
bool Dialect::operator==(const Dialect &other) const {
|
||||
return def == other.def;
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
|
|||
)
|
||||
|
||||
add_tablegen(mlir-tblgen MLIR
|
||||
DialectGen.cpp
|
||||
EnumsGen.cpp
|
||||
LLVMIRConversionGen.cpp
|
||||
LLVMIRIntrinsicGen.cpp
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// DialectGen uses the description of dialects to generate C++ definitions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/OpClass.h"
|
||||
#include "mlir/TableGen/OpInterfaces.h"
|
||||
#include "mlir/TableGen/OpTrait.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
|
||||
static llvm::cl::opt<std::string>
|
||||
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
|
||||
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
|
||||
|
||||
/// Given a set of records for a T, filter the ones that correspond to
|
||||
/// the given dialect.
|
||||
template <typename T>
|
||||
static auto filterForDialect(ArrayRef<llvm::Record *> records,
|
||||
Dialect &dialect) {
|
||||
return llvm::make_filter_range(records, [&](const llvm::Record *record) {
|
||||
return T(record).getDialect() == dialect;
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Dialect declarations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The code block for the start of a dialect class declaration.
|
||||
///
|
||||
/// {0}: The name of the dialect class.
|
||||
/// {1}: The dialect namespace.
|
||||
static const char *const dialectDeclBeginStr = R"(
|
||||
class {0} : public ::mlir::Dialect {
|
||||
public:
|
||||
explicit {0}(::mlir::MLIRContext *context);
|
||||
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
|
||||
)";
|
||||
|
||||
/// The code block for the attribute parser/printer hooks.
|
||||
static const char *const attrParserDecl = R"(
|
||||
/// Parse an attribute registered to this dialect.
|
||||
::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
|
||||
::mlir::Type type) const override;
|
||||
|
||||
/// Print an attribute registered to this dialect.
|
||||
void printAttribute(::mlir::Attribute attr,
|
||||
::mlir::DialectAsmPrinter &os) const override;
|
||||
)";
|
||||
|
||||
/// The code block for the type parser/printer hooks.
|
||||
static const char *const typeParserDecl = R"(
|
||||
/// Parse a type registered to this dialect.
|
||||
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(::mlir::Type type,
|
||||
::mlir::DialectAsmPrinter &os) const override;
|
||||
)";
|
||||
|
||||
/// The code block for the constant materializer hook.
|
||||
static const char *const constantMaterializerDecl = R"(
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
|
||||
::mlir::Attribute value,
|
||||
::mlir::Type type,
|
||||
::mlir::Location loc) override;
|
||||
)";
|
||||
|
||||
/// Generate the declaration for the given dialect class.
|
||||
static void emitDialectDecl(
|
||||
Dialect &dialect,
|
||||
FunctionTraits<decltype(&filterForDialect<Attribute>)>::result_t
|
||||
dialectAttrs,
|
||||
FunctionTraits<decltype(&filterForDialect<Type>)>::result_t dialectTypes,
|
||||
raw_ostream &os) {
|
||||
// Emit the start of the decl.
|
||||
std::string cppName = dialect.getCppClassName();
|
||||
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
|
||||
|
||||
// Check for any attributes/types registered to this dialect. If there are,
|
||||
// add the hooks for parsing/printing.
|
||||
if (!dialectAttrs.empty())
|
||||
os << attrParserDecl;
|
||||
if (!dialectTypes.empty())
|
||||
os << typeParserDecl;
|
||||
|
||||
// Add the decls for the various features of the dialect.
|
||||
if (dialect.hasConstantMaterializer())
|
||||
os << constantMaterializerDecl;
|
||||
if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
|
||||
os << *extraDecl;
|
||||
|
||||
// End the dialect decl.
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
emitSourceFileHeader("Dialect Declarations", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("Dialect");
|
||||
if (defs.empty())
|
||||
return false;
|
||||
|
||||
// Select the dialect to gen for.
|
||||
const llvm::Record *dialectDef = nullptr;
|
||||
if (defs.size() == 1 && selectedDialect.getNumOccurrences() == 0) {
|
||||
dialectDef = defs.front();
|
||||
} else if (selectedDialect.getNumOccurrences() == 0) {
|
||||
llvm::errs() << "when more than 1 dialect is present, one must be selected "
|
||||
"via '-dialect'";
|
||||
return true;
|
||||
} else {
|
||||
auto dialectIt = llvm::find_if(defs, [](const llvm::Record *def) {
|
||||
return Dialect(def).getName() == selectedDialect;
|
||||
});
|
||||
if (dialectIt == defs.end()) {
|
||||
llvm::errs() << "selected dialect with '-dialect' does not exist";
|
||||
return true;
|
||||
}
|
||||
dialectDef = *dialectIt;
|
||||
}
|
||||
|
||||
auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
|
||||
auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
|
||||
Dialect dialect(dialectDef);
|
||||
emitDialectDecl(dialect, filterForDialect<Attribute>(attrDefs, dialect),
|
||||
filterForDialect<Type>(typeDefs, dialect), os);
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: Dialect registration hooks
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::GenRegistration
|
||||
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
return emitDialectDecls(records, os);
|
||||
});
|
Loading…
Reference in New Issue