diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 8441c7507648..67b6fc11be59 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -898,34 +898,6 @@ bool parseBinaryOp(OpAsmParser *parser, OperationState *result); void printBinaryOp(const OperationInst *op, OpAsmPrinter *p); } // namespace impl -/// This template is used for operations that are simple binary ops that have -/// two input operands, one result, and whose operands and results all have -/// the same type. -/// -/// From this structure, subclasses get a standard builder, parser and printer. -/// -template class... Traits> -class BinaryOp - : public Op::Impl, OpTrait::OneResult, - OpTrait::SameOperandsAndResultType, Traits...> { -public: - static void build(Builder *builder, OperationState *result, Value *lhs, - Value *rhs) { - impl::buildBinaryOp(builder, result, lhs, rhs); - } - static bool parse(OpAsmParser *parser, OperationState *result) { - return impl::parseBinaryOp(parser, result); - } - void print(OpAsmPrinter *p) const { - return impl::printBinaryOp(this->getInstruction(), p); - } - -protected: - explicit BinaryOp(const OperationInst *state) - : Op::Impl, OpTrait::OneResult, - OpTrait::SameOperandsAndResultType, Traits...>(state) {} -}; - // These functions are out-of-line implementations of the methods in CastOp, // which avoids them being template instantiated/duplicated. namespace impl { diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 369888052e3a..7fe9f7d97828 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -208,6 +208,9 @@ class Op props = []> { // and C++ implementations. bit hasCanonicalizationPatterns = 0b0; + // Whether this op has a constant folder. + bit hasConstantFolder = 0b0; + // Op properties. list properties = props; } diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index abb8b5f821fa..4d8867ffd4d2 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -37,53 +37,8 @@ public: StandardOpsDialect(MLIRContext *context); }; -/// The "addf" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be a floating point -/// scalar type, a vector whose element type is a floating point type, or a -/// floating point tensor. For example: -/// -/// %2 = addf %0, %1 : f32 -/// -class AddFOp - : public BinaryOp { -public: - static void build(Builder *builder, OperationState *result, Value *lhs, - Value *rhs); - - static StringRef getOperationName() { return "addf"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -private: - friend class OperationInst; - explicit AddFOp(const OperationInst *state) : BinaryOp(state) {} -}; - -/// The "addi" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be an integer -/// scalar type, a vector whose element type is an integer type, or a -/// integer tensor. For example: -/// -/// %2 = addi %0, %1 : i32 -/// -class AddIOp - : public BinaryOp { -public: - static StringRef getOperationName() { return "addi"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class OperationInst; - explicit AddIOp(const OperationInst *state) : BinaryOp(state) {} -}; +#define GET_OP_CLASSES +#include "mlir/StandardOps/standard_ops.inc" /// The "alloc" operation allocates a region of memory, as specified by its /// memref type. For example: @@ -650,51 +605,6 @@ private: explicit MemRefCastOp(const OperationInst *state) : CastOp(state) {} }; -/// The "mulf" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be a floating point -/// scalar type, a vector whose element type is a floating point type, or a -/// floating point tensor. For example: -/// -/// %2 = mulf %0, %1 : f32 -/// -class MulFOp - : public BinaryOp { -public: - static StringRef getOperationName() { return "mulf"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -private: - friend class OperationInst; - explicit MulFOp(const OperationInst *state) : BinaryOp(state) {} -}; - -/// The "muli" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be an integer -/// scalar type, a vector whose element type is an integer type, or an -/// integer tensor. For example: -/// -/// %2 = muli %0, %1 : i32 -/// -class MulIOp - : public BinaryOp { -public: - static StringRef getOperationName() { return "muli"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class OperationInst; - explicit MulIOp(const OperationInst *state) : BinaryOp(state) {} -}; - /// The "select" operation chooses one value based on a binary condition /// supplied as its first operand. If the value of the first operand is 1, the /// second operand is chosen, otherwise the third operand is chosen. The second @@ -784,49 +694,6 @@ private: explicit StoreOp(const OperationInst *state) : Op(state) {} }; -/// The "subf" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be a floating point -/// scalar type, a vector whose element type is a floating point type, or a -/// floating point tensor. For example: -/// -/// %2 = subf %0, %1 : f32 -/// -class SubFOp : public BinaryOp { -public: - static StringRef getOperationName() { return "subf"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - -private: - friend class OperationInst; - explicit SubFOp(const OperationInst *state) : BinaryOp(state) {} -}; - -/// The "subi" operation takes two operands and returns one result, each of -/// these is required to be of the same type. This type may be an integer -/// scalar type, a vector whose element type is an integer type, or a -/// integer tensor. For example: -/// -/// %2 = subi %0, %1 : i32 -/// -class SubIOp : public BinaryOp { -public: - static StringRef getOperationName() { return "subi"; } - - Attribute constantFold(ArrayRef operands, - MLIRContext *context) const; - - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); - -private: - friend class OperationInst; - explicit SubIOp(const OperationInst *state) : BinaryOp(state) {} -}; - /// The "tensor_cast" operation converts a tensor from one type to an equivalent /// type without changing any data elements. The source and destination types /// must both be tensor types with the same element type, and the source and diff --git a/mlir/include/mlir/StandardOps/standard_ops.td b/mlir/include/mlir/StandardOps/standard_ops.td new file mode 100644 index 000000000000..e01a41e417b9 --- /dev/null +++ b/mlir/include/mlir/StandardOps/standard_ops.td @@ -0,0 +1,119 @@ +//===- standard_ops.td - Standard operation definitions ----*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines some MLIR standard operations. +// +//===----------------------------------------------------------------------===// + +#ifdef STANDARD_OPS +#else +#define STANDARD_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/op_base.td" +#endif // OP_BASE + +def AnyType : Type; + +// Base class for standard arithmetic operations. Requires operands and +// results to be of the same type, but does not constrain them to specific +// types. Individual classes will have `lhs` and `rhs` accessor to operands. +class ArithmeticOp props = [], + list traits = []> : + Op, + Traits, + Arguments<(ins AnyType:$lhs, AnyType:$rhs)>, + Results<[AnyType]> { + + let opName = mnemonic; + + let builder = [{ + static void build(Builder *builder, OperationState *result, Value *lhs, + Value *rhs) { + impl::buildBinaryOp(builder, result, lhs, rhs); + } + }]; + + let parser = [{ + return impl::parseBinaryOp(parser, result); + }]; + + let printer = [{ + return impl::printBinaryOp(this->getInstruction(), p); + }]; +} + +// Base class for standard arithmetic operations on integers, vectors and +// tensors thereof. This operation takes two operands and returns one result, +// each of these is required to be of the same type. This type may be an +// integer scalar type, a vector whose element type is an integer type, or an +// integer tensor. The short-hand syntax of the operaton is as follows +// +// i %0, %1 : i32 +class IntArithmeticOp props = [], + list traits = []> : + ArithmeticOp; + +// Base class for standard arithmetic binary operations on floats, vectors and +// tensors thereof. This operation has two operands and returns one result, +// each of these is required to be of the same type. This type may be a +// floating point scalar type, a vector whose element type is a floating point +// type, or a floating point tensor. The short-hand syntax of the operation is +// as follows +// +// f %0, %1 : f32 +class FloatArithmeticOp props = [], + list traits = []> : + ArithmeticOp; + +def AddFOp : FloatArithmeticOp<"addf"> { + let summary = "floating point addition operation"; + let hasConstantFolder = 0b1; +} + +def AddIOp : IntArithmeticOp<"addi", [Commutative]> { + let summary = "integer addition operation"; + let hasCanonicalizationPatterns = 0b1; + let hasConstantFolder = 0b1; +} + +def MulFOp : FloatArithmeticOp<"mulf"> { + let summary = "foating point multiplication operation"; + let hasConstantFolder = 0b1; +} + +def MulIOp : IntArithmeticOp<"muli", [Commutative]> { + let summary = "integer multiplication operation"; + let hasCanonicalizationPatterns = 0b1; + let hasConstantFolder = 0b1; +} + +def SubFOp : FloatArithmeticOp<"subf"> { + let summary = "floating point subtraction operation"; + let hasConstantFolder = 0b1; +} + +def SubIOp : IntArithmeticOp<"subi"> { + let summary = "integer subtraction operation"; + let hasConstantFolder = 0b1; + let hasCanonicalizationPatterns = 0b1; +} + +#endif // STANDARD_OPS diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index fa827661897f..836c5572f5e9 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -37,10 +37,12 @@ using namespace mlir; StandardOpsDialect::StandardOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations(); + addOperations(); } //===----------------------------------------------------------------------===// @@ -78,13 +80,6 @@ struct MemRefCastFolder : public RewritePattern { // AddFOp //===----------------------------------------------------------------------===// -void AddFOp::build(Builder *builder, OperationState *result, Value *lhs, - Value *rhs) { - assert(lhs->getType() == rhs->getType()); - result->addOperands({lhs, rhs}); - result->types.push_back(lhs->getType()); -} - Attribute AddFOp::constantFold(ArrayRef operands, MLIRContext *context) const { assert(operands.size() == 2 && "addf takes two operands"); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index c55931546bf5..8231d6b0bcd1 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -113,6 +113,9 @@ public: // Emit method declaration for the getCanonicalizationPatterns() interface. void emitCanonicalizationPatterns(); + // Emit the constant folder method for the operation. + void emitConstantFolder(); + // Emit the parser for the operation. void emitParser(); @@ -171,6 +174,7 @@ void OpEmitter::emit(const Record &def, raw_ostream &os) { emitter.emitVerifier(); emitter.emitAttrGetters(); emitter.emitCanonicalizationPatterns(); + emitter.emitConstantFolder(); os << "private:\n friend class ::mlir::OperationInst;\n"; os << " explicit " << emitter.op.cppClassName() @@ -321,6 +325,19 @@ void OpEmitter::emitCanonicalizationPatterns() { << "OwningRewritePatternList &results, MLIRContext* context);\n"; } +void OpEmitter::emitConstantFolder() { + if (!def.getValueAsBit("hasConstantFolder")) + return; + if (def.getValueAsListOfDefs("returnTypes").size() == 1) { + os << " Attribute constantFold(ArrayRef operands,\n" + " MLIRContext *context) const;\n"; + } else { + os << " bool constantFold(ArrayRef operands,\n" + << " SmallVectorImpl &results,\n" + << " MLIRContext *context) const;\n"; + } +} + void OpEmitter::emitParser() { if (!hasStringAttribute(def, "parser")) return;