forked from OSchip/llvm-project
Add support for AttrSizedOperandSegments/AttrSizedResultSegments
Certain operations can have multiple variadic operands and their size relationship is not always known statically. For such cases, we need a per-op-instance specification to divide the operands into logical groups or segments. This can be modeled by attributes. This CL introduces C++ trait AttrSizedOperandSegments for operands and AttrSizedResultSegments for results. The C++ trait just guarantees such size attribute has the correct type (1D vector) and values (non-negative), etc. It serves as the basis for ODS sugaring that with ODS argument declarations we can further verify the number of elements match the number of ODS-declared operands and we can generate handy getter methods. PiperOrigin-RevId: 282467075
This commit is contained in:
parent
174076a157
commit
13c6e419ca
|
@ -1307,8 +1307,7 @@ class NativeOpTrait<string prop> : OpTrait {
|
|||
// the value in `prop` as the trait name and the value in `params` as
|
||||
// parameters to construct the native trait class name.
|
||||
class ParamNativeOpTrait<string prop, string params>
|
||||
: NativeOpTrait<prop # "<" # params # ">::Impl"> {
|
||||
}
|
||||
: NativeOpTrait<prop # "<" # params # ">::Impl">;
|
||||
|
||||
// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
|
||||
// affects op definition generator internals, like how op builders and
|
||||
|
@ -1351,7 +1350,7 @@ def Symbol : NativeOpTrait<"Symbol">;
|
|||
// Op defines a symbol table.
|
||||
def SymbolTable : NativeOpTrait<"SymbolTable">;
|
||||
// Op is a terminator.
|
||||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
|
||||
// Op's regions have a single block with the specified terminator.
|
||||
class SingleBlockImplicitTerminator<string op>
|
||||
|
@ -1381,6 +1380,18 @@ def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
|
|||
// to have the same array size.
|
||||
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
|
||||
|
||||
// Uses an attribute named `operand_segment_sizes` to specify how many actual
|
||||
// operand each ODS-declared operand (variadic or not) corresponds to.
|
||||
// This trait is used for ops that have multiple variadic operands but do
|
||||
// not know statically their size relationship. The attribute must be a 1D
|
||||
// vector that has the same number of elements as the number of ODS declared
|
||||
// operands. That means even if some operands are non-variadic, the attribute
|
||||
// still need to have an element for its size, which is always 1.
|
||||
def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
|
||||
// Similar to AttrSizedOperandSegments, but used for results. The attribute
|
||||
// should be named as `result_segment_sizes`.
|
||||
def AttrSizedResultSegments : NativeOpTrait<"AttrSizedResultSegments">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpInterface definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -386,6 +386,8 @@ LogicalResult verifyResultsAreBoolLike(Operation *op);
|
|||
LogicalResult verifyResultsAreFloatLike(Operation *op);
|
||||
LogicalResult verifyResultsAreIntegerLike(Operation *op);
|
||||
LogicalResult verifyIsTerminator(Operation *op);
|
||||
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
|
||||
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
|
||||
} // namespace impl
|
||||
|
||||
/// Helper class for implementing traits. Clients are not expected to interact
|
||||
|
@ -907,6 +909,43 @@ template <typename ParentOpType> struct HasParent {
|
|||
};
|
||||
};
|
||||
|
||||
/// A trait for operations that have an attribute specifying operand segments.
|
||||
///
|
||||
/// Certain operations can have multiple variadic operands and their size
|
||||
/// relationship is not always known statically. For such cases, we need
|
||||
/// a per-op-instance specification to divide the operands into logical groups
|
||||
/// or segments. This can be modeled by attributes. The attribute will be named
|
||||
/// as `operand_segment_sizes`.
|
||||
///
|
||||
/// This trait verifies the attribute for specifying operand segments has
|
||||
/// the correct type (1D vector) and values (non-negative), etc.
|
||||
template <typename ConcreteType>
|
||||
class AttrSizedOperandSegments
|
||||
: public TraitBase<ConcreteType, AttrSizedOperandSegments> {
|
||||
public:
|
||||
static StringRef getOperandSegmentSizeAttr() {
|
||||
return "operand_segment_sizes";
|
||||
}
|
||||
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
|
||||
op, getOperandSegmentSizeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
/// Similar to AttrSizedOperandSegments but used for results.
|
||||
template <typename ConcreteType>
|
||||
class AttrSizedResultSegments
|
||||
: public TraitBase<ConcreteType, AttrSizedResultSegments> {
|
||||
public:
|
||||
static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
|
||||
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return ::mlir::OpTrait::impl::verifyResultSizeAttr(
|
||||
op, getResultSegmentSizeAttr());
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace OpTrait
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -136,10 +136,10 @@ public:
|
|||
Argument getArg(int index) const;
|
||||
StringRef getArgName(int index) const;
|
||||
|
||||
// Returns true if this op has the given MLIR C++ `trait`.
|
||||
// Returns the trait wrapper for the given MLIR C++ `trait`.
|
||||
// TODO: We should add a C++ wrapper class for TableGen OpTrait instead of
|
||||
// requiring the raw MLIR trait here.
|
||||
bool hasTrait(llvm::StringRef trait) const;
|
||||
const OpTrait *getTrait(llvm::StringRef trait) const;
|
||||
|
||||
// Returns the number of regions.
|
||||
unsigned getNumRegions() const;
|
||||
|
|
|
@ -957,6 +957,47 @@ LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
|
||||
bool isOperand) {
|
||||
auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName);
|
||||
if (!sizeAttr)
|
||||
return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
|
||||
|
||||
auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>();
|
||||
if (!sizeAttrType || sizeAttrType.getRank() != 1)
|
||||
return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
|
||||
|
||||
if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
|
||||
return !element.isNonNegative();
|
||||
}))
|
||||
return op->emitOpError("'")
|
||||
<< attrName << "' attribute cannot have negative elements";
|
||||
|
||||
size_t totalCount = std::accumulate(
|
||||
sizeAttr.begin(), sizeAttr.end(), 0,
|
||||
[](unsigned all, APInt one) { return all + one.getZExtValue(); });
|
||||
|
||||
if (isOperand && totalCount != op->getNumOperands())
|
||||
return op->emitOpError("operand count (")
|
||||
<< op->getNumOperands() << ") does not match with the total size ("
|
||||
<< totalCount << ") specified in attribute '" << attrName << "'";
|
||||
else if (!isOperand && totalCount != op->getNumResults())
|
||||
return op->emitOpError("result count (")
|
||||
<< op->getNumResults() << ") does not match with the total size ("
|
||||
<< totalCount << ") specified in attribute '" << attrName << "'";
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
|
||||
StringRef attrName) {
|
||||
return verifyValueSizeAttr(op, attrName, /*isOperand=*/true);
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
|
||||
StringRef attrName) {
|
||||
return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BinaryOp implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -145,20 +145,20 @@ StringRef tblgen::Operator::getArgName(int index) const {
|
|||
return argumentValues->getArgName(index)->getValue();
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasTrait(StringRef trait) const {
|
||||
for (auto t : getTraits()) {
|
||||
const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
|
||||
for (const auto &t : traits) {
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
|
||||
if (opTrait->getTrait() == trait)
|
||||
return true;
|
||||
return opTrait;
|
||||
} else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
|
||||
if (opTrait->getTrait() == trait)
|
||||
return true;
|
||||
return opTrait;
|
||||
} else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) {
|
||||
if (opTrait->getTrait() == trait)
|
||||
return true;
|
||||
return opTrait;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
|
||||
|
|
|
@ -215,3 +215,103 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
|
|||
}) : () -> ()
|
||||
func @foo() {
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedMissingOperandSizeAttr(%arg: i32) {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedOperandSizeAttrWrongType(%arg: i32) {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedOperandSizeAttrWrongRank(%arg: i32) {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedOperandSizeAttrNegativeValue(%arg: i32) {
|
||||
// expected-error @+1 {{'operand_segment_sizes' attribute cannot have negative elements}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, -1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedOperandSizeAttrWrongTotalSize(%arg: i32) {
|
||||
// expected-error @+1 {{operand count (4) does not match with the total size (3) specified in attribute 'operand_segment_sizes'}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[0, 1, 1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedOperandSizeAttrWrongCount(%arg: i32) {
|
||||
// expected-error @+1 {{'operand_segment_sizes' attribute for specifiying operand segments must have 4 elements}}
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : (i32, i32, i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @succeededOperandSizeAttr(%arg: i32) {
|
||||
// CHECK: test.attr_sized_operands
|
||||
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[0, 2, 1, 1]>: vector<4xi32>} : (i32, i32, i32, i32) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedMissingResultSizeAttr() {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}}
|
||||
%0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedResultSizeAttrWrongType() {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}}
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedResultSizeAttrWrongRank() {
|
||||
// expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}}
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedResultSizeAttrNegativeValue() {
|
||||
// expected-error @+1 {{'result_segment_sizes' attribute cannot have negative elements}}
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, -1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedResultSizeAttrWrongTotalSize() {
|
||||
// expected-error @+1 {{result count (4) does not match with the total size (3) specified in attribute 'result_segment_sizes'}}
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 1, 1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedResultSizeAttrWrongCount() {
|
||||
// expected-error @+1 {{'result_segment_sizes' attribute for specifiying result segments must have 4 elements}}
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[2, 1, 1]>: vector<3xi32>} : () -> (i32, i32, i32, i32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @succeededResultSizeAttr() {
|
||||
// CHECK: test.attr_sized_results
|
||||
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 2, 1, 1]>: vector<4xi32>} : () -> (i32, i32, i32, i32)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -413,6 +413,30 @@ def TestBranchOp : TEST_Op<"br", [Terminator]> {
|
|||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
}
|
||||
|
||||
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let arguments = (ins
|
||||
Variadic<I32>:$a,
|
||||
Variadic<I32>:$b,
|
||||
I32:$c,
|
||||
Variadic<I32>:$d,
|
||||
I32ElementsAttr:$operand_segment_sizes
|
||||
);
|
||||
}
|
||||
|
||||
def AttrSizedResultOp : TEST_Op<"attr_sized_results",
|
||||
[AttrSizedResultSegments]> {
|
||||
let arguments = (ins
|
||||
I32ElementsAttr:$result_segment_sizes
|
||||
);
|
||||
let results = (outs
|
||||
Variadic<I32>:$a,
|
||||
Variadic<I32>:$b,
|
||||
I32:$c,
|
||||
Variadic<I32>:$d
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -40,18 +40,18 @@ static const char *const tblgenNamePrefix = "tblgen_";
|
|||
static const char *const generatedArgName = "tblgen_arg";
|
||||
static const char *const builderOpState = "tblgen_state";
|
||||
|
||||
// The logic to calculate the dynamic value range for an static operand/result
|
||||
// The logic to calculate the actual value range for a declared operand/result
|
||||
// of an op with variadic operands/results. Note that this logic is not for
|
||||
// general use; it assumes all variadic operands/results must have the same
|
||||
// number of values.
|
||||
//
|
||||
// {0}: The list of whether each static operand/result is variadic.
|
||||
// {0}: The list of whether each declared operand/result is variadic.
|
||||
// {1}: The total number of non-variadic operands/results.
|
||||
// {2}: The total number of variadic operands/results.
|
||||
// {3}: The total number of dynamic values.
|
||||
// {4}: The begin iterator of the dynamic values.
|
||||
// {5}: "operand" or "result"
|
||||
const char *valueRangeCalcCode = R"(
|
||||
// {3}: The total number of actual values.
|
||||
// {4}: The begin iterator of the actual values.
|
||||
// {5}: "operand" or "result".
|
||||
const char *sameVariadicSizeValueRangeCalcCode = R"(
|
||||
bool isVariadic[] = {{{0}};
|
||||
int prevVariadicCount = 0;
|
||||
for (unsigned i = 0; i < index; ++i)
|
||||
|
@ -70,6 +70,22 @@ const char *valueRangeCalcCode = R"(
|
|||
return {{std::next({4}, offset), std::next({4}, offset + size)};
|
||||
)";
|
||||
|
||||
// The logic to calculate the actual value range for a declared operand/result
|
||||
// of an op with variadic operands/results. Note that this logic is assumes
|
||||
// the op has an attribute specifying the size of each operand/result segment
|
||||
// (variadic or not).
|
||||
//
|
||||
// {0}: The name of the attribute specifying the segment sizes.
|
||||
// {1}: The begin iterator of the actual values.
|
||||
const char *attrSizedSegmentValueRangeCalcCode = R"(
|
||||
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < index; ++i)
|
||||
start += (*(sizeAttr.begin() + i)).getZExtValue();
|
||||
unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue();
|
||||
return {{std::next({1}, start), std::next({1}, end)};
|
||||
)";
|
||||
|
||||
static const char *const opCommentHeader = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} {1}
|
||||
|
@ -239,6 +255,10 @@ class OpClass : public Class {
|
|||
public:
|
||||
explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
|
||||
|
||||
// Sets whether this OpClass should generate the using directive for its
|
||||
// associate operand adaptor class.
|
||||
void setHasOperandAdaptorClass(bool has);
|
||||
|
||||
// Adds an op trait.
|
||||
void addTrait(Twine trait);
|
||||
|
||||
|
@ -249,6 +269,7 @@ public:
|
|||
private:
|
||||
StringRef extraClassDeclaration;
|
||||
SmallVector<std::string, 4> traits;
|
||||
bool hasOperandAdaptor;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -401,7 +422,10 @@ void Class::writeDefTo(raw_ostream &os) const {
|
|||
}
|
||||
|
||||
OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
|
||||
: Class(name), extraClassDeclaration(extraClassDeclaration) {}
|
||||
: Class(name), extraClassDeclaration(extraClassDeclaration),
|
||||
hasOperandAdaptor(true) {}
|
||||
|
||||
void OpClass::setHasOperandAdaptorClass(bool has) { hasOperandAdaptor = has; }
|
||||
|
||||
// Adds the given trait to this op.
|
||||
void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); }
|
||||
|
@ -412,7 +436,8 @@ void OpClass::writeDeclTo(raw_ostream &os) const {
|
|||
os << ", " << trait;
|
||||
os << "> {\npublic:\n";
|
||||
os << " using Op::Op;\n";
|
||||
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
|
||||
if (hasOperandAdaptor)
|
||||
os << " using OperandAdaptor = " << className << "OperandAdaptor;\n";
|
||||
|
||||
bool hasPrivateMethod = false;
|
||||
for (const auto &method : methods) {
|
||||
|
@ -667,12 +692,27 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
const int numVariadicOperands = op.getNumVariadicOperands();
|
||||
const int numNormalOperands = numOperands - numVariadicOperands;
|
||||
|
||||
if (numVariadicOperands > 1 &&
|
||||
!op.hasTrait("OpTrait::SameVariadicOperandSize")) {
|
||||
const auto *sameVariadicSize =
|
||||
op.getTrait("OpTrait::SameVariadicOperandSize");
|
||||
const auto *attrSizedOperands =
|
||||
op.getTrait("OpTrait::AttrSizedOperandSegments");
|
||||
|
||||
if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
|
||||
PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
|
||||
"specification over their sizes");
|
||||
}
|
||||
|
||||
if (numVariadicOperands < 2 && attrSizedOperands) {
|
||||
PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
|
||||
"to use 'AttrSizedOperandSegments' trait");
|
||||
}
|
||||
|
||||
if (attrSizedOperands && sameVariadicSize) {
|
||||
PrintFatalError(op.getLoc(),
|
||||
"op cannot have both 'AttrSizedOperandSegments' and "
|
||||
"'SameVariadicOperandSize' traits");
|
||||
}
|
||||
|
||||
// First emit a "sink" getter method upon which we layer all nicer named
|
||||
// getter methods.
|
||||
auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
|
||||
|
@ -681,6 +721,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
// We still need to match the return type, which is a range.
|
||||
m.body() << " return {std::next(" << rangeBeginCall
|
||||
<< ", index), std::next(" << rangeBeginCall << ", index + 1)};";
|
||||
} else if (attrSizedOperands) {
|
||||
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
|
||||
"operand_segment_sizes", rangeBeginCall);
|
||||
} else {
|
||||
// Because the op can have arbitrarily interleaved variadic and non-variadic
|
||||
// operands, we need to embed a list in the "sink" getter method for
|
||||
|
@ -692,9 +735,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
}
|
||||
std::string isVariadicList = llvm::join(isVariadic, ", ");
|
||||
|
||||
m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalOperands,
|
||||
numVariadicOperands, rangeSizeCall, rangeBeginCall,
|
||||
"operand");
|
||||
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
|
||||
numNormalOperands, numVariadicOperands, rangeSizeCall,
|
||||
rangeBeginCall, "operand");
|
||||
}
|
||||
|
||||
// Then we emit nicer named getter methods by redirecting to the "sink" getter
|
||||
|
@ -716,6 +759,9 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
|
|||
}
|
||||
|
||||
void OpEmitter::genNamedOperandGetters() {
|
||||
if (op.getTrait("OpTrait::AttrSizedOperandSegments"))
|
||||
opClass.setHasOperandAdaptorClass(false);
|
||||
|
||||
generateNamedOperandGetters(
|
||||
op, opClass, /*rangeType=*/"Operation::operand_range",
|
||||
/*rangeBeginCall=*/"getOperation()->operand_begin()",
|
||||
|
@ -731,18 +777,36 @@ void OpEmitter::genNamedResultGetters() {
|
|||
// If we have more than one variadic results, we need more complicated logic
|
||||
// to calculate the value range for each result.
|
||||
|
||||
if (numVariadicResults > 1 &&
|
||||
!op.hasTrait("OpTrait::SameVariadicResultSize")) {
|
||||
const auto *sameVariadicSize = op.getTrait("OpTrait::SameVariadicResultSize");
|
||||
const auto *attrSizedResults =
|
||||
op.getTrait("OpTrait::AttrSizedResultSegments");
|
||||
|
||||
if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
|
||||
PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
|
||||
"specification over their sizes");
|
||||
}
|
||||
|
||||
if (numVariadicResults < 2 && attrSizedResults) {
|
||||
PrintFatalError(op.getLoc(), "op must have at least two variadic results "
|
||||
"to use 'AttrSizedResultSegments' trait");
|
||||
}
|
||||
|
||||
if (attrSizedResults && sameVariadicSize) {
|
||||
PrintFatalError(op.getLoc(),
|
||||
"op cannot have both 'AttrSizedResultSegments' and "
|
||||
"'SameVariadicResultSize' traits");
|
||||
}
|
||||
|
||||
auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
|
||||
"unsigned index");
|
||||
|
||||
if (numVariadicResults == 0) {
|
||||
m.body() << " return {std::next(getOperation()->result_begin(), index), "
|
||||
"std::next(getOperation()->result_begin(), index + 1)};";
|
||||
} else if (attrSizedResults) {
|
||||
m.body() << formatv(attrSizedSegmentValueRangeCalcCode,
|
||||
"result_segment_sizes",
|
||||
"getOperation()->result_begin()");
|
||||
} else {
|
||||
llvm::SmallVector<StringRef, 4> isVariadic;
|
||||
isVariadic.reserve(numResults);
|
||||
|
@ -751,8 +815,9 @@ void OpEmitter::genNamedResultGetters() {
|
|||
}
|
||||
std::string isVariadicList = llvm::join(isVariadic, ", ");
|
||||
|
||||
m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalResults,
|
||||
numVariadicResults, "getOperation()->getNumResults()",
|
||||
m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
|
||||
numNormalResults, numVariadicResults,
|
||||
"getOperation()->getNumResults()",
|
||||
"getOperation()->result_begin()", "result");
|
||||
}
|
||||
|
||||
|
@ -952,11 +1017,11 @@ void OpEmitter::genBuilder() {
|
|||
// use the first operand or attribute's type as all result types
|
||||
// to facilitate different call patterns.
|
||||
if (op.getNumVariadicResults() == 0) {
|
||||
if (op.hasTrait("OpTrait::SameOperandsAndResultType")) {
|
||||
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
|
||||
genUseOperandAsResultTypeSeparateParamBuilder();
|
||||
genUseOperandAsResultTypeCollectiveParamBuilder();
|
||||
}
|
||||
if (op.hasTrait("OpTrait::FirstAttrDerivedResultType"))
|
||||
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
|
||||
genUseAttrAsResultTypeBuilder();
|
||||
}
|
||||
}
|
||||
|
@ -1243,18 +1308,38 @@ void OpEmitter::genVerifier() {
|
|||
body << " }\n";
|
||||
}
|
||||
|
||||
genOperandResultVerifier(body, op.getOperands(), "operand");
|
||||
genOperandResultVerifier(body, op.getResults(), "result");
|
||||
const char *code = R"(
|
||||
auto sizeAttr = getAttrOfType<DenseIntElementsAttr>("{0}");
|
||||
auto numElements = sizeAttr.getType().cast<ShapedType>().getNumElements();
|
||||
if (numElements != {1}) {{
|
||||
return emitOpError("'{0}' attribute for specifiying {2} segments "
|
||||
"must have {1} elements");
|
||||
}
|
||||
)";
|
||||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||
if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||
body << tgfmt(" if (!($0)) {\n "
|
||||
"return emitOpError(\"failed to verify that $1\");\n }\n",
|
||||
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
|
||||
t->getDescription());
|
||||
} else if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
|
||||
if (t->getTrait() == "OpTrait::AttrSizedOperandSegments") {
|
||||
body << formatv(code, "operand_segment_sizes", op.getNumOperands(),
|
||||
"operand");
|
||||
} else if (t->getTrait() == "OpTrait::AttrSizedResultSegments") {
|
||||
body << formatv(code, "result_segment_sizes", op.getNumResults(),
|
||||
"result");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// These should happen after we verified the traits because
|
||||
// getODSOperands()/getODSResults() may depend on traits (e.g.,
|
||||
// AttrSizedOperandSegments/AttrSizedResultSegments).
|
||||
genOperandResultVerifier(body, op.getOperands(), "operand");
|
||||
genOperandResultVerifier(body, op.getResults(), "result");
|
||||
|
||||
genRegionVerifier(body);
|
||||
|
||||
if (hasCustomVerify) {
|
||||
|
@ -1405,7 +1490,7 @@ void OpEmitter::genOpAsmInterface() {
|
|||
// TODO: We could also add a flag to allow operations to opt in to this
|
||||
// generation, even if they only have a single operation.
|
||||
int numResults = op.getNumResults();
|
||||
if (numResults <= 1 || op.hasTrait("OpAsmOpInterface::Trait"))
|
||||
if (numResults <= 1 || op.getTrait("OpAsmOpInterface::Trait"))
|
||||
return;
|
||||
|
||||
SmallVector<StringRef, 4> resultNames(numResults);
|
||||
|
@ -1484,13 +1569,19 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
|
|||
}
|
||||
for (auto *def : defs) {
|
||||
Operator op(*def);
|
||||
const auto *attrSizedOperands =
|
||||
op.getTrait("OpTrait::AttrSizedOperandSegments");
|
||||
if (emitDecl) {
|
||||
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
|
||||
OpOperandAdaptorEmitter::emitDecl(op, os);
|
||||
// We cannot generate the operand adaptor class if operand getters depend
|
||||
// on an attribute.
|
||||
if (!attrSizedOperands)
|
||||
OpOperandAdaptorEmitter::emitDecl(op, os);
|
||||
OpEmitter::emitDecl(op, os);
|
||||
} else {
|
||||
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
|
||||
OpOperandAdaptorEmitter::emitDef(op, os);
|
||||
if (!attrSizedOperands)
|
||||
OpOperandAdaptorEmitter::emitDef(op, os);
|
||||
OpEmitter::emitDef(op, os);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -761,8 +761,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
// special cases listed below, DRR needs to supply types for all results
|
||||
// when building an op.
|
||||
bool isSameOperandsAndResultType =
|
||||
resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
|
||||
bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
|
||||
resultOp.getTrait("OpTrait::SameOperandsAndResultType");
|
||||
bool useFirstAttr = resultOp.getTrait("OpTrait::FirstAttrDerivedResultType");
|
||||
|
||||
if (isSameOperandsAndResultType || useFirstAttr) {
|
||||
// We know how to deduce the result type for ops with these traits and we've
|
||||
|
@ -780,7 +780,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
|
|||
}
|
||||
|
||||
bool isBroadcastable =
|
||||
resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
|
||||
resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
|
||||
bool usePartialResults = valuePackName != resultValue;
|
||||
|
||||
if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
|
||||
|
|
Loading…
Reference in New Issue