[mlir][shape] Switch types to ODS generated (NFC)

These were already pretty simple, so just switching to generated.
This commit is contained in:
Jacques Pienaar 2022-06-25 09:06:52 -07:00
parent e7bc73739a
commit 701051a8c2
4 changed files with 27 additions and 69 deletions

View File

@ -25,6 +25,9 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc"
namespace mlir {
class PatternRewriter;
@ -40,32 +43,6 @@ bool isExtentTensorType(Type);
// Given an input shape Value, try to obtain the shape's values.
LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues);
/// The shape descriptor type represents rank and dimension sizes.
class ShapeType : public Type::TypeBase<ShapeType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The type of a single dimension.
class SizeType : public Type::TypeBase<SizeType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The ValueShape represents a (potentially unknown) runtime value and shape.
class ValueShapeType
: public Type::TypeBase<ValueShapeType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The Witness represents a runtime constraint, to be used as shape related
/// preconditions on code execution.
class WitnessType : public Type::TypeBase<WitnessType, Type, TypeStorage> {
public:
using Base::Base;
};
} // namespace shape
} // namespace mlir

View File

@ -13,6 +13,7 @@
#ifndef SHAPE_BASE_TD
#define SHAPE_BASE_TD
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
@ -43,9 +44,11 @@ def ShapeDialect : Dialect {
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
def Shape_ShapeType : DialectType<ShapeDialect,
CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape">,
BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> {
class Shape_Type<string name, string typeMnemonic> : TypeDef<ShapeDialect, name> {
let mnemonic = typeMnemonic;
}
def Shape_ShapeType : Shape_Type<"Shape", "shape"> {
let description = [{
`shape.shape` represents either an unranked shape, a ranked shape with
possibly unknown dimensions or an invalid shape. The rank is of type
@ -62,9 +65,7 @@ def Shape_ShapeType : DialectType<ShapeDialect,
}];
}
def Shape_SizeType : DialectType<ShapeDialect,
CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size">,
BuildableType<"$_builder.getType<::mlir::shape::SizeType>()"> {
def Shape_SizeType : Shape_Type<"Size", "size"> {
let description = [{
`shape.size` represents a non-negative integer with support for being
unknown and invalid.
@ -75,10 +76,7 @@ def Shape_SizeType : DialectType<ShapeDialect,
}];
}
def Shape_ValueShapeType : DialectType<ShapeDialect,
CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape">,
BuildableType<"::mlir::shape::ValueShapeType::get($_builder.getContext())">
{
def Shape_ValueShapeType : Shape_Type<"ValueShape", "value_shape"> {
let description = [{
`shape.value_shape` represents the value produced by an operation (this
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
@ -112,9 +110,7 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
def Shape_WitnessType : DialectType<ShapeDialect,
CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
let description = [{
A witness is a structural device in the compiler to maintain ordering of
code relying on information obtained from passing assertions. Witnesses do

View File

@ -133,7 +133,10 @@ void ShapeDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
>();
addTypes<ShapeType, SizeType, ValueShapeType, WitnessType>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
>();
addInterfaces<ShapeInlinerInterface>();
// Allow unknown operations during prototyping and testing. As the dialect is
// still evolving it makes it simple to start with an unregistered ops and
@ -156,35 +159,6 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
return nullptr;
}
/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "shape")
return ShapeType::get(getContext());
if (keyword == "size")
return SizeType::get(getContext());
if (keyword == "value_shape")
return ValueShapeType::get(getContext());
if (keyword == "witness")
return WitnessType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
return Type();
}
/// Print a type registered to this dialect.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ShapeType>([&](Type) { os << "shape"; })
.Case<SizeType>([&](Type) { os << "size"; })
.Case<ValueShapeType>([&](Type) { os << "value_shape"; })
.Case<WitnessType>([&](Type) { os << "witness"; })
.Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
}
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attribute) {
// Verify shape.lib attribute.
@ -1890,3 +1864,6 @@ void ReduceOp::print(OpAsmPrinter &p) {
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"

View File

@ -2837,6 +2837,14 @@ gentbl_cc_library(
["-gen-dialect-defs"],
"include/mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc",
),
(
["-gen-typedef-decls"],
"include/mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc",
),
(
["-gen-typedef-defs"],
"include/mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Shape/IR/ShapeOps.td",