forked from OSchip/llvm-project
[mlir] Change CombiningKind in Vector dialect to EnumAttr.
CombiningKind was implemented before EnumAttr, so it reimplements the same behaviour with the custom code. Except for a few places, EnumAttr is a drop-in replacement. Reviewed By: nicolasvasilache, pifon2a Differential Revision: https://reviews.llvm.org/D133343
This commit is contained in:
parent
3262794804
commit
fcab0a04c5
|
@ -4,5 +4,7 @@ add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
|
|||
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
|
||||
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
|
||||
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
|
||||
|
|
|
@ -29,6 +29,9 @@
|
|||
// Pull in all enum type definitions and utility function declarations.
|
||||
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class RewritePatternSet;
|
||||
|
@ -113,22 +116,6 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
|||
/// chain.
|
||||
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// An attribute that specifies the combining function for `vector.contract`,
|
||||
/// and `vector.reduction`.
|
||||
class CombiningKindAttr
|
||||
: public Attribute::AttrBase<CombiningKindAttr, Attribute,
|
||||
detail::BitmaskEnumStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
|
||||
|
||||
CombiningKind getKind() const;
|
||||
|
||||
void print(AsmPrinter &p) const;
|
||||
static Attribute parse(AsmParser &parser, Type type);
|
||||
};
|
||||
|
||||
/// Collects patterns to progressively lower vector.broadcast ops on high-D
|
||||
/// vectors to low-D vector ops.
|
||||
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
|
|
@ -57,15 +57,10 @@ def CombiningKind : I32BitEnumAttr<
|
|||
let genSpecializedAttr = 0;
|
||||
}
|
||||
|
||||
def Vector_CombiningKindAttr : DialectAttr<
|
||||
Vector_Dialect,
|
||||
CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
|
||||
"Kind of combining function for contractions and reductions"> {
|
||||
let storageType = "::mlir::vector::CombiningKindAttr";
|
||||
let returnType = "::mlir::vector::CombiningKind";
|
||||
let convertFromStorage = "$_self.getKind()";
|
||||
let constBuilderCall =
|
||||
"::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
|
||||
/// An attribute that specifies the combining function for `vector.contract`,
|
||||
/// and `vector.reduction`.
|
||||
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
// TODO: Add an attribute to specify a different algebra with operators other
|
||||
|
|
|
@ -30,8 +30,8 @@
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
#include <numeric>
|
||||
|
||||
|
@ -227,91 +227,15 @@ struct BitmaskEnumStorage : public AttributeStorage {
|
|||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
||||
CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
|
||||
MLIRContext *context) {
|
||||
return Base::get(context, static_cast<uint64_t>(kind));
|
||||
}
|
||||
|
||||
CombiningKind CombiningKindAttr::getKind() const {
|
||||
return static_cast<CombiningKind>(getImpl()->value);
|
||||
}
|
||||
|
||||
static constexpr const CombiningKind combiningKindsList[] = {
|
||||
// clang-format off
|
||||
CombiningKind::ADD,
|
||||
CombiningKind::MUL,
|
||||
CombiningKind::MINUI,
|
||||
CombiningKind::MINSI,
|
||||
CombiningKind::MINF,
|
||||
CombiningKind::MAXUI,
|
||||
CombiningKind::MAXSI,
|
||||
CombiningKind::MAXF,
|
||||
CombiningKind::AND,
|
||||
CombiningKind::OR,
|
||||
CombiningKind::XOR,
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
void CombiningKindAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<";
|
||||
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
|
||||
return bitEnumContains(this->getKind(), kind);
|
||||
});
|
||||
llvm::interleaveComma(kinds, printer,
|
||||
[&](auto kind) { printer << stringifyEnum(kind); });
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
|
||||
if (failed(parser.parseLess()))
|
||||
return {};
|
||||
|
||||
StringRef elemName;
|
||||
if (failed(parser.parseKeyword(&elemName)))
|
||||
return {};
|
||||
|
||||
auto kind = symbolizeCombiningKind(elemName);
|
||||
if (!kind) {
|
||||
parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
|
||||
<< elemName;
|
||||
return {};
|
||||
}
|
||||
|
||||
if (failed(parser.parseGreater()))
|
||||
return {};
|
||||
|
||||
return CombiningKindAttr::get(*kind, parser.getContext());
|
||||
}
|
||||
|
||||
Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
|
||||
Type type) const {
|
||||
StringRef attrKind;
|
||||
if (parser.parseKeyword(&attrKind))
|
||||
return {};
|
||||
|
||||
if (attrKind == "kind")
|
||||
return CombiningKindAttr::parse(parser, {});
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
|
||||
return {};
|
||||
}
|
||||
|
||||
void VectorDialect::printAttribute(Attribute attr,
|
||||
DialectAsmPrinter &os) const {
|
||||
if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
|
||||
os << "kind";
|
||||
ck.print(os);
|
||||
return;
|
||||
}
|
||||
llvm_unreachable("Unknown attribute type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void VectorDialect::initialize() {
|
||||
addAttributes<CombiningKindAttr>();
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
|
||||
>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
@ -558,7 +482,7 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
|
|||
result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
|
||||
result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
|
||||
result.addAttribute(ContractionOp::getKindAttrStrName(),
|
||||
CombiningKindAttr::get(kind, builder.getContext()));
|
||||
CombiningKindAttr::get(builder.getContext(), kind));
|
||||
}
|
||||
|
||||
ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -587,9 +511,10 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
result.attributes.assign(dictAttr.getValue().begin(),
|
||||
dictAttr.getValue().end());
|
||||
if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
|
||||
result.addAttribute(ContractionOp::getKindAttrStrName(),
|
||||
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
|
||||
result.getContext()));
|
||||
result.addAttribute(
|
||||
ContractionOp::getKindAttrStrName(),
|
||||
CombiningKindAttr::get(result.getContext(),
|
||||
ContractionOp::getDefaultKind()));
|
||||
}
|
||||
if (masksInfo.empty())
|
||||
return success();
|
||||
|
@ -2385,8 +2310,8 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
|
||||
result.attributes.append(
|
||||
OuterProductOp::getKindAttrStrName(),
|
||||
CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
|
||||
result.getContext()));
|
||||
CombiningKindAttr::get(result.getContext(),
|
||||
OuterProductOp::getDefaultKind()));
|
||||
}
|
||||
|
||||
return failure(
|
||||
|
@ -5179,5 +5104,8 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
|
|||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
|
||||
|
|
|
@ -1111,7 +1111,8 @@ func.func @bitcast_sizemismatch(%arg0 : vector<5x1x3x2xf32>) {
|
|||
// -----
|
||||
|
||||
func.func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
|
||||
// expected-error@+1 {{custom op 'vector.reduction' Unknown combining kind: joho}}
|
||||
// expected-error@+2 {{custom op 'vector.reduction' failed to parse Vector_CombiningKindAttr parameter 'value' which is to be a `::mlir::vector::CombiningKind`}}
|
||||
// expected-error@+1 {{custom op 'vector.reduction' expected ::mlir::vector::CombiningKind to be one of: }}
|
||||
%0 = vector.reduction <joho>, %arg0 : vector<16xf32> into f32
|
||||
}
|
||||
|
||||
|
|
|
@ -7758,6 +7758,14 @@ gentbl_cc_library(
|
|||
["-gen-enum-defs"],
|
||||
"include/mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc",
|
||||
),
|
||||
(
|
||||
["-gen-attrdef-decls"],
|
||||
"include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-attrdef-defs"],
|
||||
"include/mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc",
|
||||
),
|
||||
(
|
||||
["-gen-op-doc"],
|
||||
"g3doc/Dialects/Vector/VectorOps.md",
|
||||
|
|
Loading…
Reference in New Issue