forked from OSchip/llvm-project
[mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".
The current mechanism for identifying is a bit hacky and extremely adhoc, i.e. we explicit check 1-result, 0-operand, no side-effect, and always foldable and then assume that this is a constant. Adding a trait adds structure to this, and makes checking for a constant much more efficient as we can guarantee that all of these things have already been verified. Differential Revision: https://reviews.llvm.org/D76020
This commit is contained in:
parent
7c211cf3af
commit
907403f342
|
@ -49,7 +49,8 @@ def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
|
|||
// constant operation is marked as 'NoSideEffect' as it is a pure operation
|
||||
// and may be removed if dead.
|
||||
def ConstantOp : Toy_Op<"constant",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
// Provide a summary and description for this operation. This can be used to
|
||||
// auto-generate documentation of the operations within our dialect.
|
||||
let summary = "constant";
|
||||
|
@ -295,7 +296,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> {
|
||||
def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> {
|
||||
let summary = "struct constant";
|
||||
let description = [{
|
||||
Constant operation turns a literal struct value into an SSA value. The data
|
||||
|
|
|
@ -67,7 +67,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
|
||||
def SPV_ConstantOp : SPV_Op<"constant", [ConstantLike, NoSideEffect]> {
|
||||
let summary = "The op that declares a SPIR-V normal constant";
|
||||
|
||||
let description = [{
|
||||
|
|
|
@ -796,7 +796,7 @@ def CondBranchOp : Std_Op<"cond_br",
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConstantOp : Std_Op<"constant",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
|
||||
let summary = "constant";
|
||||
|
||||
let arguments = (ins AnyAttr:$value);
|
||||
|
|
|
@ -48,8 +48,13 @@ struct attr_value_binder {
|
|||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a constant foldable operation that has no side
|
||||
/// effect, no operands and produces a single result.
|
||||
/// The matcher that matches operations that have the `ConstantLike` trait.
|
||||
struct constant_op_matcher {
|
||||
bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
|
||||
};
|
||||
|
||||
/// The matcher that matches operations that have the `ConstantLike` trait, and
|
||||
/// binds the folded attribute value.
|
||||
template <typename AttrT> struct constant_op_binder {
|
||||
AttrT *bind_value;
|
||||
|
||||
|
@ -60,20 +65,19 @@ template <typename AttrT> struct constant_op_binder {
|
|||
constant_op_binder() : bind_value(nullptr) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
|
||||
return false;
|
||||
if (!op->hasNoSideEffect())
|
||||
if (!op->hasTrait<OpTrait::ConstantLike>())
|
||||
return false;
|
||||
|
||||
// Fold the constant to an attribute.
|
||||
SmallVector<OpFoldResult, 1> foldedOp;
|
||||
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
|
||||
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
|
||||
if (auto attrT = attr.dyn_cast<AttrT>()) {
|
||||
if (bind_value)
|
||||
*bind_value = attrT;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
|
||||
(void)result;
|
||||
assert(succeeded(result) && "expected constant to be foldable");
|
||||
|
||||
if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
|
||||
if (bind_value)
|
||||
*bind_value = attr;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -201,8 +205,8 @@ struct RecursivePatternMatcher {
|
|||
} // end namespace detail
|
||||
|
||||
/// Matches a constant foldable operation.
|
||||
inline detail::constant_op_binder<Attribute> m_Constant() {
|
||||
return detail::constant_op_binder<Attribute>();
|
||||
inline detail::constant_op_matcher m_Constant() {
|
||||
return detail::constant_op_matcher();
|
||||
}
|
||||
|
||||
/// Matches a value from a constant foldable operation and writes the value to
|
||||
|
|
|
@ -1549,6 +1549,8 @@ def ResultsBroadcastableShape :
|
|||
def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
|
||||
// X op Y == Y op X
|
||||
def Commutative : NativeOpTrait<"IsCommutative">;
|
||||
// Op behaves like a constant.
|
||||
def ConstantLike : NativeOpTrait<"ConstantLike">;
|
||||
// Op behaves like a function.
|
||||
def FunctionLike : NativeOpTrait<"FunctionLike">;
|
||||
// Op is isolated from above.
|
||||
|
|
|
@ -902,6 +902,25 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for a sub-set of ops that are known to be
|
||||
/// constant-like. These are non-side effecting operations with one result and
|
||||
/// zero operands that can always be folded to a specific attribute value.
|
||||
template <typename ConcreteType>
|
||||
class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
static_assert(ConcreteType::template hasTrait<OneResult>(),
|
||||
"expected operation to produce one result");
|
||||
static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
|
||||
"expected operation to take zero operands");
|
||||
// TODO: We should verify that the operation can always be folded, but this
|
||||
// requires that the attributes of the op already be verified. We should add
|
||||
// support for verifying traits "after" the operation to enable this use
|
||||
// case.
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides the API for ops that are known to be isolated from
|
||||
/// above.
|
||||
template <typename ConcreteType>
|
||||
|
|
|
@ -399,7 +399,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
|||
cst->erase();
|
||||
return cleanupFailure();
|
||||
}
|
||||
assert(matchPattern(constOp, m_Constant(&attr)));
|
||||
assert(matchPattern(constOp, m_Constant()));
|
||||
|
||||
generatedConstants.push_back(constOp);
|
||||
results.push_back(constOp->getResult(0));
|
||||
|
|
|
@ -57,7 +57,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
|
|||
// Ask the dialect to materialize a constant operation for this value.
|
||||
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
|
||||
assert(insertPt == builder.getInsertionPoint());
|
||||
assert(matchPattern(constOp, m_Constant(&value)));
|
||||
assert(matchPattern(constOp, m_Constant()));
|
||||
return constOp;
|
||||
}
|
||||
|
||||
|
|
|
@ -454,7 +454,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
|
|||
// -----
|
||||
|
||||
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
|
||||
// expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
|
||||
// expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
|
||||
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ func @failedSameOperandElementType(%t1f: tensor<1xf32>, %t1i: tensor<1xi32>) {
|
|||
// -----
|
||||
|
||||
func @failedSameOperandAndResultElementType_no_operands() {
|
||||
// expected-error@+1 {{expected 1 or more operands}}
|
||||
// expected-error@+1 {{expected 2 operands, but found 0}}
|
||||
"test.same_operand_element_type"() : () -> tensor<1xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
|
|||
// CHECK: ArrayRef<Value> tblgen_operands;
|
||||
// CHECK: };
|
||||
|
||||
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
|
||||
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::HasNoSideEffect
|
||||
// CHECK: public:
|
||||
// CHECK: using Op::Op;
|
||||
// CHECK: using OperandAdaptor = AOpOperandAdaptor;
|
||||
|
|
|
@ -1523,14 +1523,6 @@ void OpEmitter::genTraits() {
|
|||
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
|
||||
addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
|
||||
|
||||
// Add the native and interface traits.
|
||||
for (const auto &trait : op.getTraits()) {
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
|
||||
opClass.addTrait(opTrait->getTrait());
|
||||
else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
|
||||
opClass.addTrait(opTrait->getTrait());
|
||||
}
|
||||
|
||||
// Add variadic size trait and normal op traits.
|
||||
int numOperands = op.getNumOperands();
|
||||
int numVariadicOperands = op.getNumVariadicOperands();
|
||||
|
@ -1555,6 +1547,14 @@ void OpEmitter::genTraits() {
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the native and interface traits.
|
||||
for (const auto &trait : op.getTraits()) {
|
||||
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
|
||||
opClass.addTrait(opTrait->getTrait());
|
||||
else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
|
||||
opClass.addTrait(opTrait->getTrait());
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genOpNameGetter() {
|
||||
|
|
Loading…
Reference in New Issue