forked from OSchip/llvm-project
Move the definitions for CallOp and IndirectCallOp to the Op Definition Generator.
-- PiperOrigin-RevId: 247686419
This commit is contained in:
parent
77c333ca62
commit
4a6264f5c5
|
@ -309,6 +309,11 @@ def F64 : F<64>;
|
|||
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
|
||||
BuildableType<"getBF16Type()">;
|
||||
|
||||
// Function Type
|
||||
|
||||
// Any function type.
|
||||
def FunctionType : Type<CPred<"$_self.isa<FunctionType>()">, "function type">;
|
||||
|
||||
// A container type is a type that has another type embedded within it.
|
||||
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
||||
string descr> :
|
||||
|
|
|
@ -46,77 +46,6 @@ public:
|
|||
#define GET_OP_CLASSES
|
||||
#include "mlir/StandardOps/Ops.h.inc"
|
||||
|
||||
/// The "call" operation represents a direct call to a function. The operands
|
||||
/// and result types of the call must match the specified function type. The
|
||||
/// callee is encoded as a function attribute named "callee".
|
||||
///
|
||||
/// %31 = call @my_add(%0, %1)
|
||||
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
class CallOp
|
||||
: public Op<CallOp, OpTrait::VariadicOperands, OpTrait::VariadicResults> {
|
||||
public:
|
||||
friend Operation;
|
||||
using Op::Op;
|
||||
|
||||
static StringRef getOperationName() { return "std.call"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<Value *> operands);
|
||||
|
||||
Function *getCallee() {
|
||||
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// The "call_indirect" operation represents an indirect call to a value of
|
||||
/// function type. Functions are first class types in MLIR, and may be passed
|
||||
/// as arguments and merged together with block arguments. The operands
|
||||
/// and result types of the call must match the specified function type.
|
||||
///
|
||||
/// %31 = call_indirect %15(%0, %1)
|
||||
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
///
|
||||
class CallIndirectOp : public Op<CallIndirectOp, OpTrait::VariadicOperands,
|
||||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
friend Operation;
|
||||
using Op::Op;
|
||||
static StringRef getOperationName() { return "std.call_indirect"; }
|
||||
|
||||
static void build(Builder *builder, OperationState *result, Value *callee,
|
||||
ArrayRef<Value *> operands);
|
||||
|
||||
Value *getCallee() { return getOperand(0); }
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return ++operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
LogicalResult verify();
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// The predicate indicates the type of the comparison to perform:
|
||||
/// (in)equality; (un)signed less/greater than (or equal to).
|
||||
enum class CmpIPredicate {
|
||||
|
|
|
@ -168,6 +168,87 @@ def BranchOp : Op<Standard_Dialect, "br", [Terminator]> {
|
|||
}];
|
||||
}
|
||||
|
||||
def CallOp : Op<Standard_Dialect, "call"> {
|
||||
let summary = "call operation";
|
||||
let description = [{
|
||||
The "call" operation represents a direct call to a function. The operands
|
||||
and result types of the call must match the specified function type. The
|
||||
callee is encoded as a function attribute named "callee".
|
||||
|
||||
%2 = call @my_add(%0, %1) : (f32, f32) -> f32
|
||||
}];
|
||||
|
||||
let arguments = (ins FunctionAttr:$callee, Variadic<AnyType>:$operands);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let parser = [{ return parseCallOp(parser, result); }];
|
||||
let printer = [{ return printCallOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState *result, Function *callee,"
|
||||
"ArrayRef<Value *> operands = {}", [{
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType().getResults());
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Function *getCallee() {
|
||||
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def CallIndirectOp : Op<Standard_Dialect, "call_indirect"> {
|
||||
let summary = "indirect call operation";
|
||||
let description = [{
|
||||
The "call_indirect" operation represents an indirect call to a value of
|
||||
function type. Functions are first class types in MLIR, and may be passed
|
||||
as arguments and merged together with block arguments. The operands
|
||||
and result types of the call must match the specified function type.
|
||||
|
||||
%3 = call_indirect %2(%0, %1) : (f32, f32) -> f32
|
||||
}];
|
||||
|
||||
let arguments = (ins FunctionType:$callee, Variadic<AnyType>:$operands);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
|
||||
let parser = [{ return parseCallIndirectOp(parser, result); }];
|
||||
let printer = [{ return printCallIndirectOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *, OperationState *result, Value *callee,"
|
||||
"ArrayRef<Value *> operands = {}", [{
|
||||
result->operands.push_back(callee);
|
||||
result->addOperands(operands);
|
||||
result->addTypes(callee->getType().cast<FunctionType>().getResults());
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Value *getCallee() { return getOperand(0); }
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
return {arg_operand_begin(), arg_operand_end()};
|
||||
}
|
||||
|
||||
operand_iterator arg_operand_begin() { return ++operand_begin(); }
|
||||
operand_iterator arg_operand_end() { return operand_end(); }
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 0b1;
|
||||
}
|
||||
|
||||
def ConstantOp : Op<Standard_Dialect, "constant", [NoSideEffect]> {
|
||||
let summary = "constant";
|
||||
|
||||
|
|
|
@ -61,9 +61,8 @@ void detail::printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
|
|||
|
||||
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
||||
: Dialect(/*name=*/"std", context) {
|
||||
addOperations<CallOp, CallIndirectOp, CmpFOp, CmpIOp, CondBranchOp,
|
||||
DmaStartOp, DmaWaitOp, LoadOp, MemRefCastOp, ReturnOp, SelectOp,
|
||||
StoreOp, TensorCastOp,
|
||||
addOperations<CmpFOp, CmpIOp, CondBranchOp, DmaStartOp, DmaWaitOp, LoadOp,
|
||||
MemRefCastOp, ReturnOp, SelectOp, StoreOp, TensorCastOp,
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/StandardOps/Ops.cpp.inc"
|
||||
>();
|
||||
|
@ -402,14 +401,7 @@ void BranchOp::eraseOperand(unsigned index) {
|
|||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<Value *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType().getResults());
|
||||
}
|
||||
|
||||
ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
||||
StringRef calleeName;
|
||||
llvm::SMLoc calleeLoc;
|
||||
FunctionType calleeType;
|
||||
|
@ -430,39 +422,37 @@ ParseResult CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void CallOp::print(OpAsmPrinter *p) {
|
||||
static void printCallOp(OpAsmPrinter *p, CallOp op) {
|
||||
*p << "call ";
|
||||
p->printFunctionReference(getCallee());
|
||||
p->printFunctionReference(op.getCallee());
|
||||
*p << '(';
|
||||
p->printOperands(getOperands());
|
||||
p->printOperands(op.getOperands());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : " << getCallee()->getType();
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : " << op.getCallee()->getType();
|
||||
}
|
||||
|
||||
LogicalResult CallOp::verify() {
|
||||
static LogicalResult verify(CallOp op) {
|
||||
// Check that the callee attribute was specified.
|
||||
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
|
||||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return emitOpError("requires a 'callee' function attribute");
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto fnType = fnAttr.getValue()->getType();
|
||||
if (fnType.getNumInputs() != getNumOperands())
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
if (fnType.getNumInputs() != op.getNumOperands())
|
||||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i)->getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch");
|
||||
}
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (op.getOperand(i)->getType() != fnType.getInput(i))
|
||||
return op.emitOpError("operand type mismatch");
|
||||
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
if (fnType.getNumResults() != op.getNumResults())
|
||||
return op.emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType.getResult(i))
|
||||
return emitOpError("result type mismatch");
|
||||
}
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (op.getResult(i)->getType() != fnType.getResult(i))
|
||||
return op.emitOpError("result type mismatch");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -498,15 +488,8 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
|
|||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
||||
Value *callee, ArrayRef<Value *> operands) {
|
||||
auto fnType = callee->getType().cast<FunctionType>();
|
||||
result->operands.push_back(callee);
|
||||
result->addOperands(operands);
|
||||
result->addTypes(fnType.getResults());
|
||||
}
|
||||
|
||||
ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
static ParseResult parseCallIndirectOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
FunctionType calleeType;
|
||||
OpAsmParser::OperandType callee;
|
||||
llvm::SMLoc operandsLoc;
|
||||
|
@ -524,39 +507,37 @@ ParseResult CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->addTypesToList(calleeType.getResults(), result->types));
|
||||
}
|
||||
|
||||
void CallIndirectOp::print(OpAsmPrinter *p) {
|
||||
static void printCallIndirectOp(OpAsmPrinter *p, CallIndirectOp op) {
|
||||
*p << "call_indirect ";
|
||||
p->printOperand(getCallee());
|
||||
p->printOperand(op.getCallee());
|
||||
*p << '(';
|
||||
auto operandRange = getOperands();
|
||||
auto operandRange = op.getOperands();
|
||||
p->printOperands(++operandRange.begin(), operandRange.end());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : " << getCallee()->getType();
|
||||
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
*p << " : " << op.getCallee()->getType();
|
||||
}
|
||||
|
||||
LogicalResult CallIndirectOp::verify() {
|
||||
static LogicalResult verify(CallIndirectOp op) {
|
||||
// The callee must be a function.
|
||||
auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
|
||||
auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return emitOpError("callee must have function type");
|
||||
return op.emitOpError("callee must have function type");
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
if (fnType.getNumInputs() != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
if (fnType.getNumInputs() != op.getNumOperands() - 1)
|
||||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i + 1)->getType() != fnType.getInput(i))
|
||||
return emitOpError("operand type mismatch");
|
||||
}
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
|
||||
return op.emitOpError("operand type mismatch");
|
||||
|
||||
if (fnType.getNumResults() != getNumResults())
|
||||
return emitOpError("incorrect number of results for callee");
|
||||
if (fnType.getNumResults() != op.getNumResults())
|
||||
return op.emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType.getResult(i))
|
||||
return emitOpError("result type mismatch");
|
||||
}
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (op.getResult(i)->getType() != fnType.getResult(i))
|
||||
return op.emitOpError("result type mismatch");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -23,12 +23,12 @@ def Z_AddOp : NS_Op<"add"> {
|
|||
}
|
||||
|
||||
// Define rewrite patterns.
|
||||
def : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
|
||||
def bena : Pat<(X_AddOp (X_AddOp $lhs, $rhs), $rhs), (Y_AddOp $lhs, $rhs, $rhs)>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert0
|
||||
// CHECK: GeneratedConvert0(MLIRContext *context) : RewritePattern("x.add", 2, context) {}
|
||||
// CHECK-LABEL: struct bena
|
||||
// CHECK: RewritePattern("x.add", 2, context) {}
|
||||
|
||||
def : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
|
||||
def benb : Pat<(X_AddOp $lhs, $rhs), (Z_AddOp $lhs), [(IfEqual $lhs, $rhs)], (addBenefit 100)>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert1
|
||||
// CHECK: GeneratedConvert1(MLIRContext *context) : RewritePattern("x.add", 101, context) {}
|
||||
// CHECK-LABEL: struct benb
|
||||
// CHECK: RewritePattern("x.add", 101, context) {}
|
||||
|
|
|
@ -23,13 +23,13 @@ def OneResultOp : NS_Op<"one_result_op", []> {
|
|||
let results = (outs I32:$r1);
|
||||
}
|
||||
|
||||
def : Pattern<(ThreeResultOp $input), [
|
||||
def a : Pattern<(ThreeResultOp $input), [
|
||||
(OneResultOp $input),
|
||||
(OneResultOp $input),
|
||||
(OneResultOp $input)
|
||||
]>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert0
|
||||
// CHECK-LABEL: struct a
|
||||
|
||||
// CHECK: void rewrite(
|
||||
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
|
||||
|
@ -37,13 +37,13 @@ def : Pattern<(ThreeResultOp $input), [
|
|||
// CHECK: auto vOneResultOp2 = rewriter.create<OneResultOp>(
|
||||
// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});
|
||||
|
||||
def : Pattern<(ThreeResultOp $input), [
|
||||
def b : Pattern<(ThreeResultOp $input), [
|
||||
(OneResultOp (OneResultOp:$interm $input)),
|
||||
(OneResultOp $interm),
|
||||
(OneResultOp (OneResultOp $interm))
|
||||
]>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert1
|
||||
// CHECK-LABEL: struct b
|
||||
|
||||
// CHECK: void rewrite(
|
||||
// CHECK: auto interm = rewriter.create<OneResultOp>(
|
||||
|
@ -64,7 +64,7 @@ def AdditionalOp : NS_Op<"additional_one_result_op", []> {
|
|||
let arguments = (ins I32:$input);
|
||||
let results = (outs I32:$r1);
|
||||
}
|
||||
def : Pattern<(TwoResultOp $input), [
|
||||
def c : Pattern<(TwoResultOp $input), [
|
||||
// Additional op generated to help build the final result but not
|
||||
// directly used to replace the source op
|
||||
(AdditionalOp:$interm $input),
|
||||
|
@ -73,7 +73,7 @@ def : Pattern<(TwoResultOp $input), [
|
|||
(OneResultOp $input)
|
||||
]>;
|
||||
|
||||
// CHECK-LABEL: struct GeneratedConvert2
|
||||
// CHECK-LABEL: struct c
|
||||
|
||||
// CHECK: auto interm = rewriter.create<AdditionalOp>(
|
||||
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
|
||||
|
|
Loading…
Reference in New Issue