[mlir] Add a new `ConstantLike` trait to better identify operations that represent a "constant".

The current mechanism for identifying is a bit hacky and extremely adhoc, i.e. we explicit check 1-result, 0-operand, no side-effect, and always foldable and then assume that this is a constant. Adding a trait adds structure to this, and makes checking for a constant much more efficient as we can guarantee that all of these things have already been verified.

Differential Revision: https://reviews.llvm.org/D76020
This commit is contained in:
River Riddle 2020-03-12 14:06:14 -07:00
parent 7c211cf3af
commit 907403f342
12 changed files with 58 additions and 32 deletions

View File

@ -49,7 +49,8 @@ def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>;
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
// Provide a summary and description for this operation. This can be used to
// auto-generate documentation of the operations within our dialect.
let summary = "constant";
@ -295,7 +296,7 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
let hasFolder = 1;
}
def StructConstantOp : Toy_Op<"struct_constant", [NoSideEffect]> {
def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> {
let summary = "struct constant";
let description = [{
Constant operation turns a literal struct value into an SSA value. The data

View File

@ -67,7 +67,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
// -----
def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
def SPV_ConstantOp : SPV_Op<"constant", [ConstantLike, NoSideEffect]> {
let summary = "The op that declares a SPIR-V normal constant";
let description = [{

View File

@ -796,7 +796,7 @@ def CondBranchOp : Std_Op<"cond_br",
//===----------------------------------------------------------------------===//
def ConstantOp : Std_Op<"constant",
[NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
let summary = "constant";
let arguments = (ins AnyAttr:$value);

View File

@ -48,8 +48,13 @@ struct attr_value_binder {
}
};
/// The matcher that matches a constant foldable operation that has no side
/// effect, no operands and produces a single result.
/// The matcher that matches operations that have the `ConstantLike` trait.
struct constant_op_matcher {
bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
};
/// The matcher that matches operations that have the `ConstantLike` trait, and
/// binds the folded attribute value.
template <typename AttrT> struct constant_op_binder {
AttrT *bind_value;
@ -60,20 +65,19 @@ template <typename AttrT> struct constant_op_binder {
constant_op_binder() : bind_value(nullptr) {}
bool match(Operation *op) {
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
return false;
if (!op->hasNoSideEffect())
if (!op->hasTrait<OpTrait::ConstantLike>())
return false;
// Fold the constant to an attribute.
SmallVector<OpFoldResult, 1> foldedOp;
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
if (auto attrT = attr.dyn_cast<AttrT>()) {
if (bind_value)
*bind_value = attrT;
return true;
}
}
LogicalResult result = op->fold(/*operands=*/llvm::None, foldedOp);
(void)result;
assert(succeeded(result) && "expected constant to be foldable");
if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
if (bind_value)
*bind_value = attr;
return true;
}
return false;
}
@ -201,8 +205,8 @@ struct RecursivePatternMatcher {
} // end namespace detail
/// Matches a constant foldable operation.
inline detail::constant_op_binder<Attribute> m_Constant() {
return detail::constant_op_binder<Attribute>();
inline detail::constant_op_matcher m_Constant() {
return detail::constant_op_matcher();
}
/// Matches a value from a constant foldable operation and writes the value to

View File

@ -1549,6 +1549,8 @@ def ResultsBroadcastableShape :
def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
// Op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op behaves like a function.
def FunctionLike : NativeOpTrait<"FunctionLike">;
// Op is isolated from above.

View File

@ -902,6 +902,25 @@ public:
}
};
/// This class provides the API for a sub-set of ops that are known to be
/// constant-like. These are non-side effecting operations with one result and
/// zero operands that can always be folded to a specific attribute value.
template <typename ConcreteType>
class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
static_assert(ConcreteType::template hasTrait<OneResult>(),
"expected operation to produce one result");
static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
"expected operation to take zero operands");
// TODO: We should verify that the operation can always be folded, but this
// requires that the attributes of the op already be verified. We should add
// support for verifying traits "after" the operation to enable this use
// case.
return success();
}
};
/// This class provides the API for ops that are known to be isolated from
/// above.
template <typename ConcreteType>

View File

@ -399,7 +399,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
cst->erase();
return cleanupFailure();
}
assert(matchPattern(constOp, m_Constant(&attr)));
assert(matchPattern(constOp, m_Constant()));
generatedConstants.push_back(constOp);
results.push_back(constOp->getResult(0));

View File

@ -57,7 +57,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
// Ask the dialect to materialize a constant operation for this value.
if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
assert(insertPt == builder.getInsertionPoint());
assert(matchPattern(constOp, m_Constant(&value)));
assert(matchPattern(constOp, m_Constant()));
return constOp;
}

View File

@ -454,7 +454,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
// -----
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
// expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
}

View File

@ -24,7 +24,7 @@ func @failedSameOperandElementType(%t1f: tensor<1xf32>, %t1i: tensor<1xi32>) {
// -----
func @failedSameOperandAndResultElementType_no_operands() {
// expected-error@+1 {{expected 1 or more operands}}
// expected-error@+1 {{expected 2 operands, but found 0}}
"test.same_operand_element_type"() : () -> tensor<1xf32>
}

View File

@ -55,7 +55,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect, NoSideEffect]> {
// CHECK: ArrayRef<Value> tblgen_operands;
// CHECK: };
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl
// CHECK: class AOp : public Op<AOp, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::HasNoSideEffect
// CHECK: public:
// CHECK: using Op::Op;
// CHECK: using OperandAdaptor = AOpOperandAdaptor;

View File

@ -1523,14 +1523,6 @@ void OpEmitter::genTraits() {
unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
// Add the native and interface traits.
for (const auto &trait : op.getTraits()) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
}
// Add variadic size trait and normal op traits.
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariadicOperands();
@ -1555,6 +1547,14 @@ void OpEmitter::genTraits() {
break;
}
}
// Add the native and interface traits.
for (const auto &trait : op.getTraits()) {
if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
opClass.addTrait(opTrait->getTrait());
}
}
void OpEmitter::genOpNameGetter() {