forked from OSchip/llvm-project
501 lines
18 KiB
C++
501 lines
18 KiB
C++
//===- ExtensibleDialect.cpp - Extensible dialect ---------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/ExtensibleDialect.h"
|
|
#include "mlir/IR/AttributeSupport.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/StorageUniquerSupport.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dynamic types and attributes shared functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Default parser for dynamic attribute or type parameters.
|
|
/// Parse in the format '(<>)?' or '<attr (,attr)*>'.
|
|
static LogicalResult
|
|
typeOrAttrParser(AsmParser &parser, SmallVectorImpl<Attribute> &parsedParams) {
|
|
// No parameters
|
|
if (parser.parseOptionalLess() || !parser.parseOptionalGreater())
|
|
return success();
|
|
|
|
Attribute attr;
|
|
if (parser.parseAttribute(attr))
|
|
return failure();
|
|
parsedParams.push_back(attr);
|
|
|
|
while (parser.parseOptionalGreater()) {
|
|
Attribute attr;
|
|
if (parser.parseComma() || parser.parseAttribute(attr))
|
|
return failure();
|
|
parsedParams.push_back(attr);
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Default printer for dynamic attribute or type parameters.
|
|
/// Print in the format '(<>)?' or '<attr (,attr)*>'.
|
|
static void typeOrAttrPrinter(AsmPrinter &printer, ArrayRef<Attribute> params) {
|
|
if (params.empty())
|
|
return;
|
|
|
|
printer << "<";
|
|
interleaveComma(params, printer.getStream());
|
|
printer << ">";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dynamic type
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
std::unique_ptr<DynamicTypeDefinition>
|
|
DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
|
|
VerifierFn &&verifier) {
|
|
return DynamicTypeDefinition::get(name, dialect, std::move(verifier),
|
|
typeOrAttrParser, typeOrAttrPrinter);
|
|
}
|
|
|
|
std::unique_ptr<DynamicTypeDefinition>
|
|
DynamicTypeDefinition::get(StringRef name, ExtensibleDialect *dialect,
|
|
VerifierFn &&verifier, ParserFn &&parser,
|
|
PrinterFn &&printer) {
|
|
return std::unique_ptr<DynamicTypeDefinition>(
|
|
new DynamicTypeDefinition(name, dialect, std::move(verifier),
|
|
std::move(parser), std::move(printer)));
|
|
}
|
|
|
|
/// Build a type definition from its name, owning dialect and the three hooks
/// (verifier, parser, printer). The context is taken from `dialect`.
DynamicTypeDefinition::DynamicTypeDefinition(StringRef nameRef,
                                             ExtensibleDialect *dialect,
                                             VerifierFn &&verifier,
                                             ParserFn &&parser,
                                             PrinterFn &&printer)
    : name(nameRef),
      dialect(dialect),
      verifier(std::move(verifier)),
      parser(std::move(parser)),
      printer(std::move(printer)),
      ctx(dialect->getContext()) {
}
|
|
|
|
/// Build a type definition from only its dialect and name. This constructor
/// does not set the verifier, parser or printer hooks.
DynamicTypeDefinition::DynamicTypeDefinition(ExtensibleDialect *dialect,
                                             StringRef nameRef)
    : name(nameRef),
      dialect(dialect),
      ctx(dialect->getContext()) {
}
|
|
|
|
void DynamicTypeDefinition::registerInTypeUniquer() {
|
|
detail::TypeUniquer::registerType<DynamicType>(&getContext(), getTypeID());
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace detail {
|
|
/// Storage of DynamicType.
|
|
/// Contains a pointer to the type definition and type parameters.
|
|
struct DynamicTypeStorage : public TypeStorage {
|
|
|
|
using KeyTy = std::pair<DynamicTypeDefinition *, ArrayRef<Attribute>>;
|
|
|
|
explicit DynamicTypeStorage(DynamicTypeDefinition *typeDef,
|
|
ArrayRef<Attribute> params)
|
|
: typeDef(typeDef), params(params) {}
|
|
|
|
bool operator==(const KeyTy &key) const {
|
|
return typeDef == key.first && params == key.second;
|
|
}
|
|
|
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
|
return llvm::hash_value(key);
|
|
}
|
|
|
|
static DynamicTypeStorage *construct(TypeStorageAllocator &alloc,
|
|
const KeyTy &key) {
|
|
return new (alloc.allocate<DynamicTypeStorage>())
|
|
DynamicTypeStorage(key.first, alloc.copyInto(key.second));
|
|
}
|
|
|
|
/// Definition of the type.
|
|
DynamicTypeDefinition *typeDef;
|
|
|
|
/// The type parameters.
|
|
ArrayRef<Attribute> params;
|
|
};
|
|
} // namespace detail
|
|
} // namespace mlir
|
|
|
|
/// Get (or create) the uniqued DynamicType for `typeDef` with `params`.
/// The parameters are verified only inside an assert, so the check is
/// compiled out when asserts are disabled.
DynamicType DynamicType::get(DynamicTypeDefinition *typeDef,
                             ArrayRef<Attribute> params) {
  auto &context = typeDef->getContext();
  auto defaultEmitError = detail::getDefaultDiagnosticEmitFn(&context);
  assert(succeeded(typeDef->verify(defaultEmitError, params)));
  return detail::TypeUniquer::getWithTypeID<DynamicType>(
      &context, typeDef->getTypeID(), typeDef, params);
}
|
|
|
|
/// Get (or create) the uniqued DynamicType for `typeDef` with `params`,
/// verifying the parameters first. Returns a null DynamicType when the
/// definition's verifier rejects the parameters.
DynamicType
DynamicType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                        DynamicTypeDefinition *typeDef,
                        ArrayRef<Attribute> params) {
  const bool isValid = succeeded(typeDef->verify(emitError, params));
  if (!isValid)
    return {};

  auto &context = typeDef->getContext();
  return detail::TypeUniquer::getWithTypeID<DynamicType>(
      &context, typeDef->getTypeID(), typeDef, params);
}
|
|
|
|
/// Return the definition stored in this type's storage.
DynamicTypeDefinition *DynamicType::getTypeDef() {
  auto *storage = getImpl();
  return storage->typeDef;
}
|
|
|
|
/// Return the parameters stored in this type's storage.
ArrayRef<Attribute> DynamicType::getParams() {
  auto *storage = getImpl();
  return storage->params;
}
|
|
|
|
/// A type is a DynamicType exactly when it carries the IsDynamicType trait.
bool DynamicType::classof(Type type) {
  const bool isDynamic = type.hasTrait<TypeTrait::IsDynamicType>();
  return isDynamic;
}
|
|
|
|
/// Parse the parameters of a dynamic type with the parser hook of `typeDef`,
/// then build the verified type into `parsedType`. Fails if either the
/// parameter parsing or the checked construction fails.
ParseResult DynamicType::parse(AsmParser &parser,
                               DynamicTypeDefinition *typeDef,
                               DynamicType &parsedType) {
  SmallVector<Attribute> parsedParams;
  if (failed(typeDef->parser(parser, parsedParams)))
    return failure();

  // getChecked yields a null type when verification fails.
  parsedType = parser.getChecked<DynamicType>(typeDef, parsedParams);
  if (parsedType)
    return success();
  return failure();
}
|
|
|
|
/// Print the name of the type definition, then let the definition's printer
/// hook print the parameters.
void DynamicType::print(AsmPrinter &printer) {
  DynamicTypeDefinition *definition = getTypeDef();
  printer << definition->getName();
  definition->printer(printer, getParams());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dynamic attribute
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
std::unique_ptr<DynamicAttrDefinition>
|
|
DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
|
|
VerifierFn &&verifier) {
|
|
return DynamicAttrDefinition::get(name, dialect, std::move(verifier),
|
|
typeOrAttrParser, typeOrAttrPrinter);
|
|
}
|
|
|
|
std::unique_ptr<DynamicAttrDefinition>
|
|
DynamicAttrDefinition::get(StringRef name, ExtensibleDialect *dialect,
|
|
VerifierFn &&verifier, ParserFn &&parser,
|
|
PrinterFn &&printer) {
|
|
return std::unique_ptr<DynamicAttrDefinition>(
|
|
new DynamicAttrDefinition(name, dialect, std::move(verifier),
|
|
std::move(parser), std::move(printer)));
|
|
}
|
|
|
|
/// Build an attribute definition from its name, owning dialect and the three
/// hooks (verifier, parser, printer). The context is taken from `dialect`.
DynamicAttrDefinition::DynamicAttrDefinition(StringRef nameRef,
                                             ExtensibleDialect *dialect,
                                             VerifierFn &&verifier,
                                             ParserFn &&parser,
                                             PrinterFn &&printer)
    : name(nameRef),
      dialect(dialect),
      verifier(std::move(verifier)),
      parser(std::move(parser)),
      printer(std::move(printer)),
      ctx(dialect->getContext()) {
}
|
|
|
|
/// Build an attribute definition from only its dialect and name. This
/// constructor does not set the verifier, parser or printer hooks.
DynamicAttrDefinition::DynamicAttrDefinition(ExtensibleDialect *dialect,
                                             StringRef nameRef)
    : name(nameRef),
      dialect(dialect),
      ctx(dialect->getContext()) {
}
|
|
|
|
void DynamicAttrDefinition::registerInAttrUniquer() {
|
|
detail::AttributeUniquer::registerAttribute<DynamicAttr>(&getContext(),
|
|
getTypeID());
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace detail {
|
|
/// Storage of DynamicAttr.
|
|
/// Contains a pointer to the attribute definition and attribute parameters.
|
|
struct DynamicAttrStorage : public AttributeStorage {
|
|
using KeyTy = std::pair<DynamicAttrDefinition *, ArrayRef<Attribute>>;
|
|
|
|
explicit DynamicAttrStorage(DynamicAttrDefinition *attrDef,
|
|
ArrayRef<Attribute> params)
|
|
: attrDef(attrDef), params(params) {}
|
|
|
|
bool operator==(const KeyTy &key) const {
|
|
return attrDef == key.first && params == key.second;
|
|
}
|
|
|
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
|
return llvm::hash_value(key);
|
|
}
|
|
|
|
static DynamicAttrStorage *construct(AttributeStorageAllocator &alloc,
|
|
const KeyTy &key) {
|
|
return new (alloc.allocate<DynamicAttrStorage>())
|
|
DynamicAttrStorage(key.first, alloc.copyInto(key.second));
|
|
}
|
|
|
|
/// Definition of the type.
|
|
DynamicAttrDefinition *attrDef;
|
|
|
|
/// The type parameters.
|
|
ArrayRef<Attribute> params;
|
|
};
|
|
} // namespace detail
|
|
} // namespace mlir
|
|
|
|
/// Get (or create) the uniqued DynamicAttr for `attrDef` with `params`.
/// No verification of the parameters is performed here; getChecked runs the
/// definition's verifier before delegating to this function.
DynamicAttr DynamicAttr::get(DynamicAttrDefinition *attrDef,
                             ArrayRef<Attribute> params) {
  auto &context = attrDef->getContext();
  return detail::AttributeUniquer::getWithTypeID<DynamicAttr>(
      &context, attrDef->getTypeID(), attrDef, params);
}
|
|
|
|
/// Get (or create) the uniqued DynamicAttr for `attrDef` with `params`,
/// verifying the parameters first. Returns a null DynamicAttr when the
/// definition's verifier rejects the parameters.
DynamicAttr
DynamicAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                        DynamicAttrDefinition *attrDef,
                        ArrayRef<Attribute> params) {
  const bool isValid = succeeded(attrDef->verify(emitError, params));
  if (!isValid)
    return {};
  return get(attrDef, params);
}
|
|
|
|
/// Return the definition stored in this attribute's storage.
DynamicAttrDefinition *DynamicAttr::getAttrDef() {
  auto *storage = getImpl();
  return storage->attrDef;
}
|
|
|
|
/// Return the parameters stored in this attribute's storage.
ArrayRef<Attribute> DynamicAttr::getParams() {
  auto *storage = getImpl();
  return storage->params;
}
|
|
|
|
/// An attribute is a DynamicAttr exactly when it carries the IsDynamicAttr
/// trait.
bool DynamicAttr::classof(Attribute attr) {
  const bool isDynamic = attr.hasTrait<AttributeTrait::IsDynamicAttr>();
  return isDynamic;
}
|
|
|
|
/// Parse the parameters of a dynamic attribute with the parser hook of
/// `attrDef`, then build the verified attribute into `parsedAttr`. Fails if
/// either the parameter parsing or the checked construction fails.
ParseResult DynamicAttr::parse(AsmParser &parser,
                               DynamicAttrDefinition *attrDef,
                               DynamicAttr &parsedAttr) {
  SmallVector<Attribute> parsedParams;
  if (failed(attrDef->parser(parser, parsedParams)))
    return failure();

  // getChecked yields a null attribute when verification fails.
  parsedAttr = parser.getChecked<DynamicAttr>(attrDef, parsedParams);
  if (parsedAttr)
    return success();
  return failure();
}
|
|
|
|
/// Print the name of the attribute definition, then let the definition's
/// printer hook print the parameters.
void DynamicAttr::print(AsmPrinter &printer) {
  DynamicAttrDefinition *definition = getAttrDef();
  printer << definition->getName();
  definition->printer(printer, getParams());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Dynamic operation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Build an operation definition. A fresh TypeID is allocated from `dialect`
/// and the stored operation name is "<dialect namespace>.<name>". All hooks
/// are moved into the definition.
DynamicOpDefinition::DynamicOpDefinition(
    StringRef name, ExtensibleDialect *dialect,
    OperationName::VerifyInvariantsFn &&verifyFn,
    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
    OperationName::ParseAssemblyFn &&parseFn,
    OperationName::PrintAssemblyFn &&printFn,
    OperationName::FoldHookFn &&foldHookFn,
    OperationName::GetCanonicalizationPatternsFn
        &&getCanonicalizationPatternsFn)
    : typeID(dialect->allocateTypeID()),
      name((dialect->getNamespace() + "." + name).str()),
      dialect(dialect),
      verifyFn(std::move(verifyFn)),
      verifyRegionFn(std::move(verifyRegionFn)),
      parseFn(std::move(parseFn)),
      printFn(std::move(printFn)),
      foldHookFn(std::move(foldHookFn)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {
}
|
|
|
|
/// Create an operation definition with default assembly hooks: the default
/// parser always emits an error, and the default printer prints the operation
/// in its generic form.
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
    StringRef name, ExtensibleDialect *dialect,
    OperationName::VerifyInvariantsFn &&verifyFn,
    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn) {
  // Custom assembly cannot be parsed without a user-provided parser.
  auto defaultParseFn = [](OpAsmParser &parser, OperationState &) {
    auto loc = parser.getCurrentLocation();
    return parser.emitError(
        loc, "dynamic operation do not define any parser function");
  };

  // Fall back to the generic operation syntax.
  auto defaultPrintFn = [](Operation *op, OpAsmPrinter &printer, StringRef) {
    printer.printGenericOp(op);
  };

  return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
                                  std::move(verifyRegionFn),
                                  std::move(defaultParseFn),
                                  std::move(defaultPrintFn));
}
|
|
|
|
/// Create an operation definition with default folding and canonicalization
/// hooks: the default fold hook always fails, and the default
/// canonicalization hook adds no patterns.
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
    StringRef name, ExtensibleDialect *dialect,
    OperationName::VerifyInvariantsFn &&verifyFn,
    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
    OperationName::ParseAssemblyFn &&parseFn,
    OperationName::PrintAssemblyFn &&printFn) {
  // Never folds.
  auto defaultFoldHookFn = [](Operation *, ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
    return failure();
  };

  // Registers no canonicalization pattern.
  auto defaultCanonicalizationFn = [](RewritePatternSet &, MLIRContext *) {};

  return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
                                  std::move(verifyRegionFn), std::move(parseFn),
                                  std::move(printFn),
                                  std::move(defaultFoldHookFn),
                                  std::move(defaultCanonicalizationFn));
}
|
|
|
|
std::unique_ptr<DynamicOpDefinition>
|
|
DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
|
|
OperationName::VerifyInvariantsFn &&verifyFn,
|
|
OperationName::VerifyInvariantsFn &&verifyRegionFn,
|
|
OperationName::ParseAssemblyFn &&parseFn,
|
|
OperationName::PrintAssemblyFn &&printFn,
|
|
OperationName::FoldHookFn &&foldHookFn,
|
|
OperationName::GetCanonicalizationPatternsFn
|
|
&&getCanonicalizationPatternsFn) {
|
|
return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
|
|
name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
|
|
std::move(parseFn), std::move(printFn), std::move(foldHookFn),
|
|
std::move(getCanonicalizationPatternsFn)));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Extensible dialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// Interface that can only be implemented by extensible dialects.
|
|
/// The interface is used to check if a dialect is extensible or not.
|
|
class IsExtensibleDialect : public DialectInterface::Base<IsExtensibleDialect> {
|
|
public:
|
|
IsExtensibleDialect(Dialect *dialect) : Base(dialect) {}
|
|
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsExtensibleDialect)
|
|
};
|
|
} // namespace
|
|
|
|
/// Construct the dialect and attach the IsExtensibleDialect marker interface
/// that classof relies on.
ExtensibleDialect::ExtensibleDialect(StringRef name, MLIRContext *ctx,
                                     TypeID typeID)
    : Dialect(name, ctx, typeID) {
  addInterfaces<IsExtensibleDialect>();
}
|
|
|
|
void ExtensibleDialect::registerDynamicType(
|
|
std::unique_ptr<DynamicTypeDefinition> &&type) {
|
|
DynamicTypeDefinition *typePtr = type.get();
|
|
TypeID typeID = type->getTypeID();
|
|
StringRef name = type->getName();
|
|
ExtensibleDialect *dialect = type->getDialect();
|
|
|
|
assert(dialect == this &&
|
|
"trying to register a dynamic type in the wrong dialect");
|
|
|
|
// If a type with the same name is already defined, fail.
|
|
auto registered = dynTypes.try_emplace(typeID, std::move(type)).second;
|
|
(void)registered;
|
|
assert(registered && "type TypeID was not unique");
|
|
|
|
registered = nameToDynTypes.insert({name, typePtr}).second;
|
|
(void)registered;
|
|
assert(registered &&
|
|
"Trying to create a new dynamic type with an existing name");
|
|
|
|
auto abstractType =
|
|
AbstractType::get(*dialect, DynamicAttr::getInterfaceMap(),
|
|
DynamicType::getHasTraitFn(), typeID);
|
|
|
|
/// Add the type to the dialect and the type uniquer.
|
|
addType(typeID, std::move(abstractType));
|
|
typePtr->registerInTypeUniquer();
|
|
}
|
|
|
|
void ExtensibleDialect::registerDynamicAttr(
|
|
std::unique_ptr<DynamicAttrDefinition> &&attr) {
|
|
auto *attrPtr = attr.get();
|
|
auto typeID = attr->getTypeID();
|
|
auto name = attr->getName();
|
|
auto *dialect = attr->getDialect();
|
|
|
|
assert(dialect == this &&
|
|
"trying to register a dynamic attribute in the wrong dialect");
|
|
|
|
// If an attribute with the same name is already defined, fail.
|
|
auto registered = dynAttrs.try_emplace(typeID, std::move(attr)).second;
|
|
(void)registered;
|
|
assert(registered && "attribute TypeID was not unique");
|
|
|
|
registered = nameToDynAttrs.insert({name, attrPtr}).second;
|
|
(void)registered;
|
|
assert(registered &&
|
|
"Trying to create a new dynamic attribute with an existing name");
|
|
|
|
auto abstractAttr =
|
|
AbstractAttribute::get(*dialect, DynamicAttr::getInterfaceMap(),
|
|
DynamicAttr::getHasTraitFn(), typeID);
|
|
|
|
/// Add the type to the dialect and the type uniquer.
|
|
addAttribute(typeID, std::move(abstractAttr));
|
|
attrPtr->registerInAttrUniquer();
|
|
}
|
|
|
|
void ExtensibleDialect::registerDynamicOp(
|
|
std::unique_ptr<DynamicOpDefinition> &&op) {
|
|
assert(op->dialect == this &&
|
|
"trying to register a dynamic op in the wrong dialect");
|
|
auto hasTraitFn = [](TypeID traitId) { return false; };
|
|
|
|
RegisteredOperationName::insert(
|
|
op->name, *op->dialect, op->typeID, std::move(op->parseFn),
|
|
std::move(op->printFn), std::move(op->verifyFn),
|
|
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
|
|
std::move(op->getCanonicalizationPatternsFn),
|
|
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
|
|
}
|
|
|
|
bool ExtensibleDialect::classof(const Dialect *dialect) {
|
|
return const_cast<Dialect *>(dialect)
|
|
->getRegisteredInterface<IsExtensibleDialect>();
|
|
}
|
|
|
|
/// Try to parse a dynamic type named `typeName`.
/// Returns no result when this dialect has no dynamic type definition with
/// that name; otherwise returns whether parsing succeeded, and on success
/// stores the parsed type in `resultType`.
OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
    StringRef typeName, AsmParser &parser, Type &resultType) const {
  DynamicTypeDefinition *definition = lookupTypeDefinition(typeName);
  if (definition == nullptr)
    return llvm::None;

  DynamicType parsedType;
  if (failed(DynamicType::parse(parser, definition, parsedType)))
    return failure();

  resultType = parsedType;
  return success();
}
|
|
|
|
/// Print `type` if it is a DynamicType. Returns failure, printing nothing,
/// when it is not.
LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
                                                    AsmPrinter &printer) {
  auto dynamicType = type.dyn_cast<DynamicType>();
  if (!dynamicType)
    return failure();

  dynamicType.print(printer);
  return success();
}
|
|
|
|
/// Try to parse a dynamic attribute named `attrName`.
/// Returns no result when this dialect has no dynamic attribute definition
/// with that name; otherwise returns whether parsing succeeded, and on
/// success stores the parsed attribute in `resultAttr`.
OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
    StringRef attrName, AsmParser &parser, Attribute &resultAttr) const {
  DynamicAttrDefinition *definition = lookupAttrDefinition(attrName);
  if (definition == nullptr)
    return llvm::None;

  DynamicAttr parsedAttr;
  if (failed(DynamicAttr::parse(parser, definition, parsedAttr)))
    return failure();

  resultAttr = parsedAttr;
  return success();
}
|
|
|
|
/// Print `attribute` if it is a DynamicAttr. Returns failure, printing
/// nothing, when it is not.
LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
                                                    AsmPrinter &printer) {
  auto dynamicAttr = attribute.dyn_cast<DynamicAttr>();
  if (!dynamicAttr)
    return failure();

  dynamicAttr.print(printer);
  return success();
}
|