2018-10-11 05:23:30 +08:00
|
|
|
//===- BuiltinOps.cpp - Builtin MLIR Operations -------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
#include "mlir/IR/SSAValue.h"
|
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
#include "mlir/Support/MathExtras.h"
|
|
|
|
#include "mlir/Support/STLExtras.h"
|
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
|
2018-10-22 10:49:31 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// BuiltinDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Constructs the builtin dialect and registers its operations
/// (AffineApplyOp, ConstantOp, ReturnOp) with it.
/// NOTE(review): the empty `opPrefix` presumably means these ops are spelled
/// without a dialect namespace (e.g. `constant`, `return`) — confirm against
/// the Dialect class.
BuiltinDialect::BuiltinDialect(MLIRContext *context)
    : Dialect(/*opPrefix=*/"", context) {
  addOperations<AffineApplyOp, ConstantOp, ReturnOp>();
}
|
|
|
|
|
2018-10-11 05:23:30 +08:00
|
|
|
/// Prints the operands in [begin, end) as a parenthesized list of the first
/// `numDims` (dimension) operands, followed by a square-bracketed list of the
/// remaining (symbol) operands. The bracketed list is omitted entirely when
/// there are no symbol operands.
void mlir::printDimAndSymbolList(Operation::const_operand_iterator begin,
                                 Operation::const_operand_iterator end,
                                 unsigned numDims, OpAsmPrinter *p) {
  auto symbolsBegin = begin + numDims;
  OpAsmPrinter &printer = *p;

  printer << '(';
  printer.printOperands(begin, symbolsBegin);
  printer << ')';

  // No symbols: nothing more to print.
  if (symbolsBegin == end)
    return;

  printer << '[';
  printer.printOperands(symbolsBegin, end);
  printer << ']';
}
|
|
|
|
|
|
|
|
// Parses dimension and symbol list, and sets 'numDims' to the number of
|
|
|
|
// dimension operands parsed.
|
|
|
|
// Returns 'false' on success and 'true' on error.
|
|
|
|
bool mlir::parseDimAndSymbolList(OpAsmParser *parser,
|
|
|
|
SmallVector<SSAValue *, 4> &operands,
|
|
|
|
unsigned &numDims) {
|
|
|
|
SmallVector<OpAsmParser::OperandType, 8> opInfos;
|
|
|
|
if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
|
|
|
|
return true;
|
|
|
|
// Store number of dimensions for validation by caller.
|
|
|
|
numDims = opInfos.size();
|
|
|
|
|
|
|
|
// Parse the optional symbol operands.
|
|
|
|
auto *affineIntTy = parser->getBuilder().getIndexType();
|
|
|
|
if (parser->parseOperandList(opInfos, -1,
|
|
|
|
OpAsmParser::Delimiter::OptionalSquare) ||
|
|
|
|
parser->resolveOperands(opInfos, affineIntTy, operands))
|
|
|
|
return true;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// AffineApplyOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds an affine_apply of `map` over `operands`. One index-typed result is
/// created per map result, and the map is attached as the "map" attribute.
void AffineApplyOp::build(Builder *builder, OperationState *result,
                          AffineMap map, ArrayRef<SSAValue *> operands) {
  auto *indexType = builder->getIndexType();
  result->addOperands(operands);
  result->types.append(map.getNumResults(), indexType);
  result->addAttribute("map", builder->getAffineMapAttr(map));
}
|
|
|
|
|
|
|
|
/// Parses the custom form of affine_apply:
///   affine_apply <map> (<dims>)[<symbols>] {<attrs>}
/// then checks the operand counts against the map and adds one index-typed
/// result per map result. Returns true on error.
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
  auto *indexType = parser->getBuilder().getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  bool failed = parser->parseAttribute(mapAttr, "map", result->attributes) ||
                parseDimAndSymbolList(parser, result->operands, numDims) ||
                parser->parseOptionalAttributeDict(result->attributes);
  if (failed)
    return true;

  auto map = mapAttr.getValue();
  bool dimsMatch = map.getNumDims() == numDims;
  bool operandCountMatches =
      numDims + map.getNumSymbols() == result->operands.size();
  if (!dimsMatch || !operandCountMatches)
    return parser->emitError(parser->getNameLoc(),
                             "dimension or symbol index mismatch");

  result->types.append(map.getNumResults(), indexType);
  return false;
}
|
|
|
|
|
|
|
|
/// Prints affine_apply in its custom form. The map is printed inline, so it is
/// elided from the trailing attribute dictionary.
void AffineApplyOp::print(OpAsmPrinter *p) const {
  auto map = getAffineMap();
  OpAsmPrinter &printer = *p;

  printer << "affine_apply " << map;
  printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
  printer.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
|
|
|
|
|
|
|
|
bool AffineApplyOp::verify() const {
|
|
|
|
// Check that affine map attribute was specified.
|
2018-10-26 06:46:10 +08:00
|
|
|
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
2018-10-11 05:23:30 +08:00
|
|
|
if (!affineMapAttr)
|
|
|
|
return emitOpError("requires an affine map");
|
|
|
|
|
|
|
|
// Check input and output dimensions match.
|
2018-10-26 06:46:10 +08:00
|
|
|
auto map = affineMapAttr.getValue();
|
2018-10-11 05:23:30 +08:00
|
|
|
|
|
|
|
// Verify that operand count matches affine map dimension and symbol count.
|
|
|
|
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
|
|
|
|
return emitOpError(
|
|
|
|
"operand count and affine map dimension and symbol count must match");
|
|
|
|
|
|
|
|
// Verify that result count matches affine map result count.
|
|
|
|
if (getNumResults() != map.getNumResults())
|
|
|
|
return emitOpError("result count and affine map result count must match");
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
// The result of the affine apply operation can be used as a dimension id if it
|
|
|
|
// is a CFG value or if it is an MLValue, and all the operands are valid
|
|
|
|
// dimension ids.
|
|
|
|
bool AffineApplyOp::isValidDim() const {
|
|
|
|
for (auto *op : getOperands()) {
|
|
|
|
if (auto *v = dyn_cast<MLValue>(op))
|
|
|
|
if (!v->isValidDim())
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
// The result of the affine apply operation can be used as a symbol if it is
|
|
|
|
// a CFG value or if it is an MLValue, and all the operands are symbols.
|
|
|
|
bool AffineApplyOp::isValidSymbol() const {
|
|
|
|
for (auto *op : getOperands()) {
|
|
|
|
if (auto *v = dyn_cast<MLValue>(op))
|
|
|
|
if (!v->isValidSymbol())
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
|
|
|
SmallVectorImpl<Attribute> &results,
|
2018-10-11 05:23:30 +08:00
|
|
|
MLIRContext *context) const {
|
|
|
|
auto map = getAffineMap();
|
|
|
|
if (map.constantFold(operandConstants, results))
|
|
|
|
return true;
|
|
|
|
// Return false on success.
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constant*Op
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds a constant op with the specified attribute value and result type.
|
|
|
|
void ConstantOp::build(Builder *builder, OperationState *result,
|
2018-10-26 06:46:10 +08:00
|
|
|
Attribute value, Type *type) {
|
2018-10-11 05:23:30 +08:00
|
|
|
result->addAttribute("value", value);
|
|
|
|
result->types.push_back(type);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Prints `constant <value> {attrs}` followed by ` : <type>`. The type suffix
/// is omitted for function references, whose attribute already carries it.
void ConstantOp::print(OpAsmPrinter *p) const {
  auto value = getValue();
  *p << "constant " << value;
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");

  bool isFunctionRef = value.isa<FunctionAttr>();
  if (isFunctionRef)
    return;
  *p << " : " << *getType();
}
|
|
|
|
|
|
|
|
/// Parses `constant <attr> {attrs}` and, unless the attribute is a function
/// reference, a trailing `: <type>`. Returns true on error.
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
  Attribute valueAttr;
  if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
      parser->parseOptionalAttributeDict(result->attributes))
    return true;

  // A function reference carries its own type, so no explicit type follows;
  // the result type is taken from the referenced function.
  auto fnAttr = valueAttr.dyn_cast<FunctionAttr>();
  if (fnAttr)
    return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);

  Type *type;
  if (parser->parseColonType(type))
    return true;
  return parser->addTypeToList(type, result->types);
}
|
|
|
|
|
|
|
|
/// The constant op requires an attribute, and furthermore requires that it
|
|
|
|
/// matches the return type.
|
|
|
|
bool ConstantOp::verify() const {
|
2018-10-26 06:46:10 +08:00
|
|
|
auto value = getValue();
|
2018-10-11 05:23:30 +08:00
|
|
|
if (!value)
|
|
|
|
return emitOpError("requires a 'value' attribute");
|
|
|
|
|
|
|
|
auto *type = this->getType();
|
|
|
|
if (isa<IntegerType>(type) || type->isIndex()) {
|
2018-10-26 06:46:10 +08:00
|
|
|
if (!value.isa<IntegerAttr>())
|
2018-10-11 05:23:30 +08:00
|
|
|
return emitOpError(
|
|
|
|
"requires 'value' to be an integer for an integer result type");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (isa<FloatType>(type)) {
|
2018-10-26 06:46:10 +08:00
|
|
|
if (!value.isa<FloatAttr>())
|
2018-10-11 05:23:30 +08:00
|
|
|
return emitOpError("requires 'value' to be a floating point constant");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (type->isTFString()) {
|
2018-10-26 06:46:10 +08:00
|
|
|
if (!value.isa<StringAttr>())
|
2018-10-11 05:23:30 +08:00
|
|
|
return emitOpError("requires 'value' to be a string constant");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (isa<FunctionType>(type)) {
|
2018-10-26 06:46:10 +08:00
|
|
|
if (!value.isa<FunctionAttr>())
|
2018-10-11 05:23:30 +08:00
|
|
|
return emitOpError("requires 'value' to be a function reference");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
return emitOpError(
|
|
|
|
"requires a result type that aligns with the 'value' attribute");
|
|
|
|
}
|
|
|
|
|
2018-10-26 06:46:10 +08:00
|
|
|
/// Folding a constant yields its own 'value' attribute; a constant never has
/// operands to consider.
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
                                   MLIRContext *context) const {
  assert(operands.empty() && "constant has no operands");
  auto value = getValue();
  return value;
}
|
|
|
|
|
|
|
|
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
2018-10-21 09:31:49 +08:00
|
|
|
const APFloat &value, FloatType *type) {
|
2018-10-11 05:23:30 +08:00
|
|
|
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
|
|
|
|
}
|
|
|
|
|
|
|
|
bool ConstantFloatOp::isClassFor(const Operation *op) {
|
|
|
|
return ConstantOp::isClassFor(op) &&
|
|
|
|
isa<FloatType>(op->getResult(0)->getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
|
|
|
bool ConstantIntOp::isClassFor(const Operation *op) {
|
|
|
|
return ConstantOp::isClassFor(op) &&
|
|
|
|
isa<IntegerType>(op->getResult(0)->getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Builds an integer constant holding `value` whose result type is an
/// IntegerType of the given bit width.
void ConstantIntOp::build(Builder *builder, OperationState *result,
                          int64_t value, unsigned width) {
  auto *intType = builder->getIntegerType(width);
  ConstantOp::build(builder, result, builder->getIntegerAttr(value), intType);
}
|
|
|
|
|
2018-10-12 08:21:55 +08:00
|
|
|
/// Build a constant int op producing an integer with the specified type,
|
|
|
|
/// which must be an integer type.
|
|
|
|
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
|
|
|
int64_t value, Type *type) {
|
|
|
|
assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
|
|
|
|
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
|
|
|
|
}
|
|
|
|
|
2018-10-11 05:23:30 +08:00
|
|
|
/// ConstantIndexOp only matches values whose result type is Index.
|
|
|
|
bool ConstantIndexOp::isClassFor(const Operation *op) {
|
|
|
|
return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Builds an index-typed constant holding `value`.
void ConstantIndexOp::build(Builder *builder, OperationState *result,
                            int64_t value) {
  auto *indexType = builder->getIndexType();
  ConstantOp::build(builder, result, builder->getIntegerAttr(value), indexType);
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ReturnOp
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Builds a return op whose operands are the values being returned. No results
/// or attributes are added.
void ReturnOp::build(Builder *builder, OperationState *result,
                     ArrayRef<SSAValue *> results) {
  auto &state = *result;
  state.addOperands(results);
}
|
|
|
|
|
|
|
|
/// Parses `return` followed by an optional operand list. When operands are
/// present they must be followed by `: <type-list>`. Returns true on error.
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
  llvm::SMLoc loc;
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  SmallVector<Type *, 2> operandTypes;

  if (parser->getCurrentLocation(&loc) ||
      parser->parseOperandList(operandInfos))
    return true;

  // A bare `return` has no type list to parse.
  if (!operandInfos.empty() && parser->parseColonTypeList(operandTypes))
    return true;

  return parser->resolveOperands(operandInfos, operandTypes, loc,
                                 result->operands);
}
|
|
|
|
|
|
|
|
/// Prints `return` and, if there are operands, ` <operands> : <types>` with
/// the operand types separated by ", ".
void ReturnOp::print(OpAsmPrinter *p) const {
  *p << "return";
  if (getNumOperands() == 0)
    return;

  *p << ' ';
  p->printOperands(operand_begin(), operand_end());
  *p << " : ";
  interleave(
      operand_begin(), operand_end(),
      [&](const SSAValue *operand) { p->printType(operand->getType()); },
      [&]() { *p << ", "; });
}
|
|
|
|
|
|
|
|
bool ReturnOp::verify() const {
|
|
|
|
// ReturnOp must be part of an ML function.
|
|
|
|
if (auto *stmt = dyn_cast<OperationStmt>(getOperation())) {
|
|
|
|
StmtBlock *block = stmt->getBlock();
|
|
|
|
if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
|
|
|
|
return emitOpError("must be the last statement in the ML function");
|
|
|
|
|
2018-10-23 23:46:26 +08:00
|
|
|
// The operand number and types must match the function signature.
|
|
|
|
MLFunction *function = cast<MLFunction>(block);
|
|
|
|
const auto &results = function->getType()->getResults();
|
|
|
|
if (stmt->getNumOperands() != results.size())
|
|
|
|
return emitOpError("has " + Twine(stmt->getNumOperands()) +
|
|
|
|
" operands, but enclosing function returns " +
|
|
|
|
Twine(results.size()));
|
|
|
|
|
|
|
|
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
|
|
|
if (stmt->getOperand(i)->getType() != results[i]) {
|
|
|
|
emitError("type of return operand " + Twine(i) +
|
|
|
|
" doesn't match function result type");
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-11 05:23:30 +08:00
|
|
|
// Return success. Checking that operand types match those in the function
|
|
|
|
// signature is performed in the ML function verifier.
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return emitOpError("cannot occur in a CFG function");
|
|
|
|
}
|