Move the definitions for CallOp and IndirectCallOp to the Op Definition Generator.

--

PiperOrigin-RevId: 247686419
This commit is contained in:
River Riddle 2019-05-10 15:27:34 -07:00 committed by Mehdi Amini
parent 77c333ca62
commit 4a6264f5c5
6 changed files with 139 additions and 143 deletions

View File

@ -309,6 +309,11 @@ def F64 : F<64>;
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"getBF16Type()">;
// Function Type
// Any function type.
def FunctionType : Type<CPred<"$_self.isa<FunctionType>()">, "function type">;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr> :

View File

@ -46,77 +46,6 @@ public:
#define GET_OP_CLASSES
#include "mlir/StandardOps/Ops.h.inc"
/// The "call" operation represents a direct call to a function. The operands
/// and result types of the call must match the specified function type. The
/// callee is encoded as a function attribute named "callee".
///
/// %31 = call @my_add(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
class CallOp
: public Op<CallOp, OpTrait::VariadicOperands, OpTrait::VariadicResults> {
public:
friend Operation;
using Op::Op;
static StringRef getOperationName() { return "std.call"; }
static void build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<Value *> operands);
Function *getCallee() {
return getAttrOfType<FunctionAttr>("callee").getValue();
}
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
};
/// The "call_indirect" operation represents an indirect call to a value of
/// function type. Functions are first class types in MLIR, and may be passed
/// as arguments and merged together with block arguments. The operands
/// and result types of the call must match the specified function type.
///
/// %31 = call_indirect %15(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
///
class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
friend Operation;
using Op::Op;
static StringRef getOperationName() { return "std.call_indirect"; }
static void build(Builder *builder, OperationState *result, Value *callee,
ArrayRef<Value *> operands);
Value *getCallee() { return getOperand(0); }
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
operand_iterator arg_operand_begin() { return ++operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
LogicalResult verify();
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
};
/// The predicate indicates the type of the comparison to perform:
/// (in)equality; (un)signed less/greater than (or equal to).
enum class CmpIPredicate {

View File

@ -168,6 +168,87 @@ def BranchOp : Op<Standard_Dialect, "br", [Terminator]> {
}];
}
def CallOp : Op<Standard_Dialect, "call"> {
let summary = "call operation";
let description = [{
The "call" operation represents a direct call to a function. The operands
and result types of the call must match the specified function type. The
callee is encoded as a function attribute named "callee".
%2 = call @my_add(%0, %1) : (f32, f32) -> f32
}];
let arguments = (ins FunctionAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ return printCallOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let builders = [OpBuilder<
"Builder *builder, OperationState *result, Function *callee,"
"ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
}]>];
let extraClassDeclaration = [{
Function *getCallee() {
return getAttrOfType<FunctionAttr>("callee").getValue();
}
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
}];
}
def CallIndirectOp : Op<Standard_Dialect, "call_indirect"> {
let summary = "indirect call operation";
let description = [{
The "call_indirect" operation represents an indirect call to a value of
function type. Functions are first class types in MLIR, and may be passed
as arguments and merged together with block arguments. The operands
and result types of the call must match the specified function type.
%3 = call_indirect %2(%0, %1) : (f32, f32) -> f32
}];
let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let parser = [{ return parseCallIndirectOp(parser, result); }];
let printer = [{ return printCallIndirectOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let builders = [OpBuilder<
"Builder *, OperationState *result, Value *callee,"
"ArrayRef<Value *> operands = {}", [{
result->operands.push_back(callee);
result->addOperands(operands);
result->addTypes(callee->getType().cast<FunctionType>().getResults());
}]>];
let extraClassDeclaration = [{
Value *getCallee() { return getOperand(0); }
/// Get the argument operands to the called function.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
operand_iterator arg_operand_begin() { return ++operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
}];
let hasCanonicalizer = 0b1;
}
def ConstantOp : Op<Standard_Dialect, "constant", [NoSideEffect]> {
let summary = "constant";

View File

@ -61,9 +61,8 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*name=*/"std", context) {
addOperations<CallOp, CallIndirectOp, CmpFOp, CmpIOp, CondBranchOp,
DmaStartOp, DmaWaitOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
StoreOp, TensorCastOp,
addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
#define GET_OP_LIST
#include "mlir/StandardOps/Ops.cpp.inc"
>();
@ -402,14 +401,7 @@ void BranchOp::eraseOperand(unsigned index) {
// CallOp
//===----------------------------------------------------------------------===//
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<Value *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
}
ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
FunctionType calleeType;
@ -430,39 +422,37 @@ ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
return success();
}
void CallOp::print(OpAsmPrinter *p) {
static void printCallOp(OpAsmPrinter *p, CallOp op) {
*p << "call ";
p->printFunctionReference(getCallee());
p->printFunctionReference(op.getCallee());
*p << '(';
p->printOperands(getOperands());
p->printOperands(op.getOperands());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
*p << " : " << getCallee()->getType();
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
*p << " : " << op.getCallee()->getType();
}
LogicalResult CallOp::verify() {
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' function attribute");
return op.emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee.
auto fnType = fnAttr.getValue()->getType();
if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (op.getOperand(i)->getType() != fnType.getInput(i))
return op.emitOpError("operand type mismatch");
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
if (fnType.getNumResults() != op.getNumResults())
return op.emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
if (op.getResult(i)->getType() != fnType.getResult(i))
return op.emitOpError("result type mismatch");
return success();
}
@ -498,15 +488,8 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
};
} // end anonymous namespace.
void CallIndirectOp::build(Builder *builder, OperationState *result,
Value *callee, ArrayRef<Value *> operands) {
auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee);
result->addOperands(operands);
result->addTypes(fnType.getResults());
}
ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
static ParseResult parseCallIndirectOp(OpAsmParser *parser,
OperationState *result) {
FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
@ -524,39 +507,37 @@ ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
parser->addTypesToList(calleeType.getResults(), result->types));
}
void CallIndirectOp::print(OpAsmPrinter *p) {
static void printCallIndirectOp(OpAsmPrinter *p, CallIndirectOp op) {
*p << "call_indirect ";
p->printOperand(getCallee());
p->printOperand(op.getCallee());
*p << '(';
auto operandRange = getOperands();
auto operandRange = op.getOperands();
p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
*p << " : " << getCallee()->getType();
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
*p << " : " << op.getCallee()->getType();
}
LogicalResult CallIndirectOp::verify() {
static LogicalResult verify(CallIndirectOp op) {
// The callee must be a function.
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
if (!fnType)
return emitOpError("callee must have function type");
return op.emitOpError("callee must have function type");
// Verify that the operand and result types match the callee.
if (fnType.getNumInputs() != getNumOperands() - 1)
return emitOpError("incorrect number of operands for callee");
if (fnType.getNumInputs() != op.getNumOperands() - 1)
return op.emitOpError("incorrect number of operands for callee");
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
if (getOperand(i + 1)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
return op.emitOpError("operand type mismatch");
if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
if (fnType.getNumResults() != op.getNumResults())
return op.emitOpError("incorrect number of results for callee");
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
if (op.getResult(i)->getType() != fnType.getResult(i))
return op.emitOpError("result type mismatch");
return success();
}

View File

@ -23,12 +23,12 @@ def Z_AddOp : NS_Op<"add"> {
}
// Define rewrite patterns.
def : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
// CHECK-LABEL: struct GeneratedConvert0
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("x.add", 2, context) {}
// CHECK-LABEL: struct bena
// CHECK: RewritePattern("x.add", 2, context) {}
def : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
// CHECK-LABEL: struct GeneratedConvert1
// CHECK: GeneratedConvert1(MLIRContext *context) : RewritePattern("x.add", 101, context) {}
// CHECK-LABEL: struct benb
// CHECK: RewritePattern("x.add", 101, context) {}

View File

@ -23,13 +23,13 @@ def OneResultOp : NS_Op<"one_result_op", []> {
let results = (outs I32:$r1);
}
def : Pattern<(ThreeResultOp $input), [
def a : Pattern<(ThreeResultOp $input), [
(OneResultOp $input),
(OneResultOp $input),
(OneResultOp $input)
]>;
// CHECK-LABEL: struct GeneratedConvert0
// CHECK-LABEL: struct a
// CHECK: void rewrite(
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
@ -37,13 +37,13 @@ def : Pattern<(ThreeResultOp $input), [
// CHECK: auto vOneResultOp2 = rewriter.create<OneResultOp>(
// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});
def : Pattern<(ThreeResultOp $input), [
def b : Pattern<(ThreeResultOp $input), [
(OneResultOp (OneResultOp:$interm $input)),
(OneResultOp $interm),
(OneResultOp (OneResultOp $interm))
]>;
// CHECK-LABEL: struct GeneratedConvert1
// CHECK-LABEL: struct b
// CHECK: void rewrite(
// CHECK: auto interm = rewriter.create<OneResultOp>(
@ -64,7 +64,7 @@ def AdditionalOp : NS_Op<"additional_one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
}
def : Pattern<(TwoResultOp $input), [
def c : Pattern<(TwoResultOp $input), [
// Additional op generated to help build the final result but not
// directly used to replace the source op
(AdditionalOp:$interm $input),
@ -73,7 +73,7 @@ def : Pattern<(TwoResultOp $input), [
(OneResultOp $input)
]>;
// CHECK-LABEL: struct GeneratedConvert2
// CHECK-LABEL: struct c
// CHECK: auto interm = rewriter.create<AdditionalOp>(
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(