[mlir] Change CombiningKind in Vector dialect to EnumAttr.

CombiningKind was implemented before EnumAttr, so it reimplements the same behaviour with the custom code. Except for a few places, EnumAttr is a drop-in replacement.

Reviewed By: nicolasvasilache, pifon2a

Differential Revision: https://reviews.llvm.org/D133343
This commit is contained in:
Oleg Shyshkov 2022-09-07 13:33:02 +02:00 committed by Alexander Belyaev
parent 3262794804
commit fcab0a04c5
6 changed files with 34 additions and 113 deletions

View File

@ -4,5 +4,7 @@ add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)

View File

@ -29,6 +29,9 @@
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
namespace mlir {
class MLIRContext;
class RewritePatternSet;
@ -113,22 +116,6 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
/// chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
: public Attribute::AttrBase<CombiningKindAttr, Attribute,
detail::BitmaskEnumStorage> {
public:
using Base::Base;
static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
CombiningKind getKind() const;
void print(AsmPrinter &p) const;
static Attribute parse(AsmParser &parser, Type type);
};
/// Collects patterns to progressively lower vector.broadcast ops on high-D
/// vectors to low-D vector ops.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);

View File

@ -57,15 +57,10 @@ def CombiningKind : I32BitEnumAttr<
let genSpecializedAttr = 0;
}
def Vector_CombiningKindAttr : DialectAttr<
Vector_Dialect,
CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
"Kind of combining function for contractions and reductions"> {
let storageType = "::mlir::vector::CombiningKindAttr";
let returnType = "::mlir::vector::CombiningKind";
let convertFromStorage = "$_self.getKind()";
let constBuilderCall =
"::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}
// TODO: Add an attribute to specify a different algebra with operators other

View File

@ -30,8 +30,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include <numeric>
@ -227,91 +227,15 @@ struct BitmaskEnumStorage : public AttributeStorage {
} // namespace vector
} // namespace mlir
CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
MLIRContext *context) {
return Base::get(context, static_cast<uint64_t>(kind));
}
CombiningKind CombiningKindAttr::getKind() const {
return static_cast<CombiningKind>(getImpl()->value);
}
static constexpr const CombiningKind combiningKindsList[] = {
// clang-format off
CombiningKind::ADD,
CombiningKind::MUL,
CombiningKind::MINUI,
CombiningKind::MINSI,
CombiningKind::MINF,
CombiningKind::MAXUI,
CombiningKind::MAXSI,
CombiningKind::MAXF,
CombiningKind::AND,
CombiningKind::OR,
CombiningKind::XOR,
// clang-format on
};
void CombiningKindAttr::print(AsmPrinter &printer) const {
printer << "<";
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
return bitEnumContains(this->getKind(), kind);
});
llvm::interleaveComma(kinds, printer,
[&](auto kind) { printer << stringifyEnum(kind); });
printer << ">";
}
Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};
StringRef elemName;
if (failed(parser.parseKeyword(&elemName)))
return {};
auto kind = symbolizeCombiningKind(elemName);
if (!kind) {
parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
<< elemName;
return {};
}
if (failed(parser.parseGreater()))
return {};
return CombiningKindAttr::get(*kind, parser.getContext());
}
Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
if (attrKind == "kind")
return CombiningKindAttr::parse(parser, {});
parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
return {};
}
void VectorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
os << "kind";
ck.print(os);
return;
}
llvm_unreachable("Unknown attribute type");
}
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
void VectorDialect::initialize() {
addAttributes<CombiningKindAttr>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
@ -558,7 +482,7 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
result.addAttribute(ContractionOp::getKindAttrStrName(),
CombiningKindAttr::get(kind, builder.getContext()));
CombiningKindAttr::get(builder.getContext(), kind));
}
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
@ -587,9 +511,10 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
result.addAttribute(ContractionOp::getKindAttrStrName(),
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
result.getContext()));
result.addAttribute(
ContractionOp::getKindAttrStrName(),
CombiningKindAttr::get(result.getContext(),
ContractionOp::getDefaultKind()));
}
if (masksInfo.empty())
return success();
@ -2385,8 +2310,8 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
result.attributes.append(
OuterProductOp::getKindAttrStrName(),
CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
result.getContext()));
CombiningKindAttr::get(result.getContext(),
OuterProductOp::getDefaultKind()));
}
return failure(
@ -5179,5 +5104,8 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

View File

@ -1111,7 +1111,8 @@ func.func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
// -----
func.func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
// expected-error@+1 {{custom op 'vector.reduction' Unknown combining kind: joho}}
// expected-error@+2 {{custom op 'vector.reduction' failed to parse Vector_CombiningKindAttr parameter 'value' which is to be a `::mlir::vector::CombiningKind`}}
// expected-error@+1 {{custom op 'vector.reduction' expected ::mlir::vector::CombiningKind to be one of: }}
%0 = vector.reduction <joho>, %arg0 : vector<16xf32> into f32
}

View File

@ -7758,6 +7758,14 @@ gentbl_cc_library(
["-gen-enum-defs"],
"include/mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc",
),
(
["-gen-attrdef-decls"],
"include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc",
),
(
["-gen-attrdef-defs"],
"include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc",
),
(
["-gen-op-doc"],
"g3doc/Dialects/Vector/VectorOps.md",