2018-07-05 11:45:39 +08:00
|
|
|
//===- Operation.cpp - MLIR Operation Class -------------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
|
|
|
|
#include "mlir/IR/Operation.h"
|
2018-07-06 12:20:59 +08:00
|
|
|
#include "AttributeListStorage.h"
|
2018-08-21 23:42:19 +08:00
|
|
|
#include "mlir/IR/CFGFunction.h"
|
2018-11-10 06:04:03 +08:00
|
|
|
#include "mlir/IR/Dialect.h"
|
2018-07-23 12:02:26 +08:00
|
|
|
#include "mlir/IR/Instructions.h"
|
2018-08-21 23:42:19 +08:00
|
|
|
#include "mlir/IR/MLFunction.h"
|
2018-08-02 01:18:59 +08:00
|
|
|
#include "mlir/IR/MLIRContext.h"
|
2018-09-10 11:40:23 +08:00
|
|
|
#include "mlir/IR/OpDefinition.h"
|
2018-09-27 01:07:16 +08:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2018-07-23 12:02:26 +08:00
|
|
|
#include "mlir/IR/Statements.h"
|
2018-11-10 06:04:03 +08:00
|
|
|
|
2018-07-05 11:45:39 +08:00
|
|
|
using namespace mlir;
|
|
|
|
|
2018-10-10 13:08:52 +08:00
|
|
|
/// Form the OperationName for an op with the specified string. This either is
|
|
|
|
/// a reference to an AbstractOperation if one is known, or a uniqued Identifier
|
|
|
|
/// if not.
|
|
|
|
OperationName::OperationName(StringRef name, MLIRContext *context) {
|
2018-10-22 10:49:31 +08:00
|
|
|
if (auto *op = AbstractOperation::lookup(name, context))
|
2018-10-10 13:08:52 +08:00
|
|
|
representation = op;
|
|
|
|
else
|
|
|
|
representation = Identifier::get(name, context);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the name of this operation. This always succeeds.
|
|
|
|
StringRef OperationName::getStringRef() const {
|
|
|
|
if (auto *op = representation.dyn_cast<const AbstractOperation *>())
|
|
|
|
return op->name;
|
|
|
|
return representation.get<Identifier>().strref();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the AbstractOperation this name refers to, or null when the name is
/// held as a plain Identifier.
const AbstractOperation *OperationName::getAbstractOperation() const {
  auto *registered = representation.dyn_cast<const AbstractOperation *>();
  return registered;
}
|
|
|
|
|
|
|
|
/// Rebuild an OperationName from the opaque pointer form of its
/// representation union.
OperationName OperationName::getFromOpaquePointer(void *pointer) {
  auto repr = RepresentationUnion::getFromOpaqueValue(pointer);
  return OperationName(repr);
}
|
|
|
|
|
2018-10-22 10:49:31 +08:00
|
|
|
// Out-of-line definition of the OpAsmParser destructor; it does no work.
OpAsmParser::~OpAsmParser() = default;
|
|
|
|
|
2018-10-10 13:08:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Operation class
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Construct an operation with the given name and attributes. `isInstruction`
/// is stored alongside the name and records which concrete subclass this
/// object is. The attribute list is uniqued in `context`.
Operation::Operation(bool isInstruction, OperationName name,
                     ArrayRef<NamedAttribute> attrs, MLIRContext *context)
    : nameAndIsInstruction(name, isInstruction) {
#ifndef NDEBUG
  // Validate the incoming list *before* handing it to the uniquer, so a null
  // entry is reported by this assertion rather than by whatever the uniquer
  // happens to do with it.
  for (const auto &elt : attrs)
    assert(elt.second != nullptr && "Attributes cannot have null entries");
#endif

  // Note: the `attrs` parameter shadows the member of the same name.
  this->attrs = AttributeListStorage::get(attrs, context);
}
|
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
// Out-of-line definition of the Operation destructor; it does no work.
Operation::~Operation() = default;
|
|
|
|
|
2018-08-02 01:18:59 +08:00
|
|
|
/// Return the context this operation is associated with.
|
|
|
|
MLIRContext *Operation::getContext() const {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-08-02 01:18:59 +08:00
|
|
|
return inst->getContext();
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->getContext();
|
2018-08-02 01:18:59 +08:00
|
|
|
}
|
|
|
|
|
2018-08-24 05:32:25 +08:00
|
|
|
/// The source location the operation was defined or derived from. Note that
|
|
|
|
/// it is possible for this pointer to be null.
|
2018-11-09 04:28:35 +08:00
|
|
|
Location Operation::getLoc() const {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-08-24 05:32:25 +08:00
|
|
|
return inst->getLoc();
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->getLoc();
|
2018-08-24 05:32:25 +08:00
|
|
|
}
|
|
|
|
|
2018-08-21 23:42:19 +08:00
|
|
|
/// Return the function this operation is defined in.
|
|
|
|
Function *Operation::getOperationFunction() {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-08-21 23:42:19 +08:00
|
|
|
return inst->getFunction();
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->findFunction();
|
2018-08-21 23:42:19 +08:00
|
|
|
}
|
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
/// Return the number of operands this operation has.
|
|
|
|
unsigned Operation::getNumOperands() const {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-07-23 12:02:26 +08:00
|
|
|
return inst->getNumOperands();
|
2018-08-21 23:42:19 +08:00
|
|
|
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->getNumOperands();
|
2018-07-23 12:02:26 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the operand at position `idx`.
SSAValue *Operation::getOperand(unsigned idx) {
  if (auto *asInst = llvm::dyn_cast<Instruction>(this)) {
    return asInst->getOperand(idx);
  } else {
    auto *asStmt = llvm::cast<OperationStmt>(this);
    return asStmt->getOperand(idx);
  }
}
|
|
|
|
|
2018-07-24 01:08:00 +08:00
|
|
|
void Operation::setOperand(unsigned idx, SSAValue *value) {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this)) {
|
2018-10-20 00:07:58 +08:00
|
|
|
inst->setOperand(idx, llvm::cast<CFGValue>(value));
|
2018-07-24 01:08:00 +08:00
|
|
|
} else {
|
2018-10-20 00:07:58 +08:00
|
|
|
auto *stmt = llvm::cast<OperationStmt>(this);
|
|
|
|
stmt->setOperand(idx, llvm::cast<MLValue>(value));
|
2018-07-24 01:08:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
/// Return the number of results this operation has.
|
|
|
|
unsigned Operation::getNumResults() const {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-07-23 12:02:26 +08:00
|
|
|
return inst->getNumResults();
|
2018-08-21 23:42:19 +08:00
|
|
|
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->getNumResults();
|
2018-07-23 12:02:26 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the indicated result.
|
|
|
|
SSAValue *Operation::getResult(unsigned idx) {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-07-23 12:02:26 +08:00
|
|
|
return inst->getResult(idx);
|
2018-08-21 23:42:19 +08:00
|
|
|
|
2018-10-20 00:07:58 +08:00
|
|
|
return llvm::cast<OperationStmt>(this)->getResult(idx);
|
2018-07-05 11:45:39 +08:00
|
|
|
}
|
|
|
|
|
2018-11-16 01:56:06 +08:00
|
|
|
unsigned Operation::getNumSuccessors() const {
|
|
|
|
assert(isTerminator() && "Only terminators have successors.");
|
2018-11-16 10:31:15 +08:00
|
|
|
if (llvm::isa<Instruction>(this))
|
|
|
|
return llvm::cast<Instruction>(this)->getNumSuccessors();
|
2018-11-16 01:56:06 +08:00
|
|
|
|
|
|
|
// OperationStmt currently only has a return terminator.
|
2018-11-16 03:37:33 +08:00
|
|
|
assert(llvm::cast<OperationStmt>(this)->isReturn() &&
|
|
|
|
"Unhandled OperationStmt terminator.");
|
2018-11-16 01:56:06 +08:00
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
unsigned Operation::getNumSuccessorOperands(unsigned index) const {
|
|
|
|
assert(isTerminator() && "Only terminators have successors.");
|
2018-11-16 10:31:15 +08:00
|
|
|
assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
|
|
|
|
return llvm::cast<Instruction>(this)->getNumSuccessorOperands(index);
|
2018-11-16 01:56:06 +08:00
|
|
|
}
|
|
|
|
/// Return the successor block at `index`. Only valid on terminator
/// instructions.
BasicBlock *Operation::getSuccessor(unsigned index) {
  assert(isTerminator() && "Only terminators have successors.");
  assert(llvm::isa<Instruction>(this) &&
         "Only instructions have basic block successors.");
  auto *asInst = llvm::cast<Instruction>(this);
  return asInst->getSuccessor(index);
}
|
|
|
|
/// Set the successor at `index` to `block`. Only valid on terminator
/// instructions.
void Operation::setSuccessor(BasicBlock *block, unsigned index) {
  assert(isTerminator() && "Only terminators have successors.");
  assert(llvm::isa<Instruction>(this) &&
         "Only instructions have basic block successors.");
  auto *asInst = llvm::cast<Instruction>(this);
  asInst->setSuccessor(block, index);
}
|
|
|
|
void Operation::addSuccessorOperand(unsigned index, SSAValue *value) {
|
|
|
|
assert(isTerminator() && "Only terminators have successors.");
|
2018-11-16 10:31:15 +08:00
|
|
|
assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
|
|
|
|
return llvm::cast<Instruction>(this)->addSuccessorOperand(
|
2018-11-16 01:56:06 +08:00
|
|
|
index, llvm::cast<CFGValue>(value));
|
|
|
|
}
|
2018-11-21 06:03:41 +08:00
|
|
|
/// Erase operand `opIndex` of the successor at `succIndex`. Only valid on
/// terminator instructions.
void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
  assert(isTerminator() && "Only terminators have successors.");
  assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
  auto *asInst = llvm::cast<Instruction>(this);
  asInst->eraseSuccessorOperand(succIndex, opIndex);
}
|
2018-11-16 01:56:06 +08:00
|
|
|
auto Operation::getSuccessorOperands(unsigned index) const
|
|
|
|
-> llvm::iterator_range<const_operand_iterator> {
|
|
|
|
assert(isTerminator() && "Only terminators have successors.");
|
2018-11-16 10:31:15 +08:00
|
|
|
assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
|
2018-11-16 01:56:06 +08:00
|
|
|
unsigned succOperandIndex =
|
2018-11-16 10:31:15 +08:00
|
|
|
llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
|
2018-11-16 01:56:06 +08:00
|
|
|
return {const_operand_iterator(this, succOperandIndex),
|
|
|
|
const_operand_iterator(this, succOperandIndex +
|
|
|
|
getNumSuccessorOperands(index))};
|
|
|
|
}
|
|
|
|
auto Operation::getSuccessorOperands(unsigned index)
|
|
|
|
-> llvm::iterator_range<operand_iterator> {
|
|
|
|
assert(isTerminator() && "Only terminators have successors.");
|
2018-11-16 10:31:15 +08:00
|
|
|
assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
|
2018-11-16 01:56:06 +08:00
|
|
|
unsigned succOperandIndex =
|
2018-11-16 10:31:15 +08:00
|
|
|
llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
|
2018-11-16 01:56:06 +08:00
|
|
|
return {operand_iterator(this, succOperandIndex),
|
|
|
|
operand_iterator(this,
|
|
|
|
succOperandIndex + getNumSuccessorOperands(index))};
|
|
|
|
}
|
|
|
|
|
2018-10-17 00:31:45 +08:00
|
|
|
/// Return true if there are no users of any results of this operation.
|
|
|
|
bool Operation::use_empty() const {
|
|
|
|
for (auto *result : getResults())
|
|
|
|
if (!result->use_empty())
|
|
|
|
return false;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-11-21 00:48:08 +08:00
|
|
|
/// Move this operation before `existingOp`. `existingOp` is cast to the same
/// concrete kind (Instruction or OperationStmt) as this operation.
void Operation::moveBefore(Operation *existingOp) {
  if (auto *asInst = llvm::dyn_cast<Instruction>(this)) {
    asInst->moveBefore(llvm::cast<Instruction>(existingOp));
  } else {
    auto *asStmt = llvm::cast<OperationStmt>(this);
    asStmt->moveBefore(llvm::cast<OperationStmt>(existingOp));
  }
}
|
|
|
|
|
2018-07-06 12:20:59 +08:00
|
|
|
/// Return the attributes attached to this operation, or an empty list when
/// no attribute storage is set.
ArrayRef<NamedAttribute> Operation::getAttrs() const {
  if (attrs)
    return attrs->getElements();
  return {};
}
|
|
|
|
|
2018-07-06 00:12:11 +08:00
|
|
|
/// If an attribute exists with the specified name, change it to the new
|
|
|
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
2018-10-26 06:46:10 +08:00
|
|
|
void Operation::setAttr(Identifier name, Attribute value) {
|
2018-07-06 00:12:11 +08:00
|
|
|
assert(value && "attributes may never be null");
|
2018-07-06 12:20:59 +08:00
|
|
|
auto origAttrs = getAttrs();
|
|
|
|
|
|
|
|
SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
|
2018-08-06 12:12:29 +08:00
|
|
|
auto *context = getContext();
|
2018-07-06 12:20:59 +08:00
|
|
|
|
2018-07-06 00:12:11 +08:00
|
|
|
// If we already have this attribute, replace it.
|
2018-07-06 12:20:59 +08:00
|
|
|
for (auto &elt : newAttrs)
|
2018-07-06 00:12:11 +08:00
|
|
|
if (elt.first == name) {
|
|
|
|
elt.second = value;
|
2018-07-06 12:20:59 +08:00
|
|
|
attrs = AttributeListStorage::get(newAttrs, context);
|
2018-07-06 00:12:11 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, add it.
|
2018-07-06 12:20:59 +08:00
|
|
|
newAttrs.push_back({name, value});
|
|
|
|
attrs = AttributeListStorage::get(newAttrs, context);
|
2018-07-06 00:12:11 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Remove the attribute with the specified name if it exists. The return
|
|
|
|
/// value indicates whether the attribute was present or not.
|
2018-08-06 12:12:29 +08:00
|
|
|
auto Operation::removeAttr(Identifier name) -> RemoveResult {
|
2018-07-06 12:20:59 +08:00
|
|
|
auto origAttrs = getAttrs();
|
|
|
|
for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
|
|
|
|
if (origAttrs[i].first == name) {
|
|
|
|
SmallVector<NamedAttribute, 8> newAttrs;
|
|
|
|
newAttrs.reserve(origAttrs.size() - 1);
|
|
|
|
newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
|
|
|
|
newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
|
2018-08-06 12:12:29 +08:00
|
|
|
attrs = AttributeListStorage::get(newAttrs, getContext());
|
2018-07-05 11:45:39 +08:00
|
|
|
return RemoveResult::Removed;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return RemoveResult::NotFound;
|
|
|
|
}
|
2018-08-02 01:18:59 +08:00
|
|
|
|
2018-08-06 12:12:29 +08:00
|
|
|
/// Emit a note about this operation, reporting up to any diagnostic
|
|
|
|
/// handlers that may be listening.
|
|
|
|
void Operation::emitNote(const Twine &message) const {
|
2018-08-24 05:32:25 +08:00
|
|
|
getContext()->emitDiagnostic(getLoc(), message,
|
2018-08-06 12:12:29 +08:00
|
|
|
MLIRContext::DiagnosticKind::Note);
|
|
|
|
}
|
|
|
|
|
2018-08-02 01:18:59 +08:00
|
|
|
/// Emit a warning about this operation, reporting up to any diagnostic
|
|
|
|
/// handlers that may be listening.
|
|
|
|
void Operation::emitWarning(const Twine &message) const {
|
2018-08-24 05:32:25 +08:00
|
|
|
getContext()->emitDiagnostic(getLoc(), message,
|
2018-08-06 12:12:29 +08:00
|
|
|
MLIRContext::DiagnosticKind::Warning);
|
2018-08-02 01:18:59 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Emit an error about fatal conditions with this operation, reporting up to
|
|
|
|
/// any diagnostic handlers that may be listening. NOTE: This may terminate
|
|
|
|
/// the containing application, only use when the IR is in an inconsistent
|
|
|
|
/// state.
|
|
|
|
void Operation::emitError(const Twine &message) const {
|
2018-08-24 05:32:25 +08:00
|
|
|
getContext()->emitDiagnostic(getLoc(), message,
|
2018-08-06 12:12:29 +08:00
|
|
|
MLIRContext::DiagnosticKind::Error);
|
2018-08-02 01:18:59 +08:00
|
|
|
}
|
2018-09-10 11:40:23 +08:00
|
|
|
|
|
|
|
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
|
|
|
/// convenient for verifiers.
|
|
|
|
bool Operation::emitOpError(const Twine &message) const {
|
2018-10-10 13:08:52 +08:00
|
|
|
emitError(Twine('\'') + getName().getStringRef() + "' op " + message);
|
2018-09-10 11:40:23 +08:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-22 10:53:10 +08:00
|
|
|
/// Remove this operation from its parent block and delete it.
|
|
|
|
void Operation::erase() {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-10-22 10:53:10 +08:00
|
|
|
return inst->erase();
|
|
|
|
return llvm::cast<OperationStmt>(this)->erase();
|
|
|
|
}
|
|
|
|
|
2018-09-20 12:35:11 +08:00
|
|
|
/// Attempt to constant fold this operation with the specified constant
|
|
|
|
/// operand values. If successful, this returns false and fills in the
|
|
|
|
/// results vector. If not, this returns true and results is unspecified.
|
2018-10-26 06:46:10 +08:00
|
|
|
bool Operation::constantFold(ArrayRef<Attribute> operands,
|
|
|
|
SmallVectorImpl<Attribute> &results) const {
|
2018-11-10 06:04:03 +08:00
|
|
|
if (auto *abstractOp = getAbstractOperation()) {
|
|
|
|
// If we have a registered operation definition matching this one, use it to
|
|
|
|
// try to constant fold the operation.
|
2018-09-20 12:35:11 +08:00
|
|
|
if (!abstractOp->constantFoldHook(this, operands, results))
|
|
|
|
return false;
|
|
|
|
|
2018-11-10 06:04:03 +08:00
|
|
|
// Otherwise, fall back on the dialect hook to handle it.
|
2018-11-21 06:47:10 +08:00
|
|
|
return abstractOp->dialect.constantFoldHook(this, operands, results);
|
2018-11-10 06:04:03 +08:00
|
|
|
}
|
2018-11-21 06:47:10 +08:00
|
|
|
|
|
|
|
// If this operation hasn't been registered or doesn't have abstract
|
|
|
|
// operation, fall back to a dialect which matches the prefix.
|
|
|
|
auto opName = getName().getStringRef();
|
|
|
|
if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
|
|
|
|
return dialect->constantFoldHook(this, operands, results);
|
|
|
|
}
|
|
|
|
|
2018-09-20 12:35:11 +08:00
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2018-10-29 01:03:19 +08:00
|
|
|
/// Print this operation to `os`, forwarding to the concrete Instruction or
/// OperationStmt.
void Operation::print(raw_ostream &os) const {
  if (auto *asInst = llvm::dyn_cast<Instruction>(this)) {
    asInst->print(os);
  } else {
    auto *asStmt = llvm::cast<OperationStmt>(this);
    asStmt->print(os);
  }
}
|
|
|
|
|
|
|
|
void Operation::dump() const {
|
2018-11-16 10:31:15 +08:00
|
|
|
if (auto *inst = llvm::dyn_cast<Instruction>(this))
|
2018-10-29 01:03:19 +08:00
|
|
|
return inst->dump();
|
|
|
|
return llvm::cast<OperationStmt>(this)->dump();
|
|
|
|
}
|
|
|
|
|
2018-10-23 04:08:27 +08:00
|
|
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
|
|
|
bool Operation::classof(const Statement *stmt) {
|
|
|
|
return stmt->getKind() == Statement::Kind::Operation;
|
|
|
|
}
|
2018-10-29 01:03:19 +08:00
|
|
|
bool Operation::classof(const IROperandOwner *ptr) {
|
2018-11-16 10:31:15 +08:00
|
|
|
return ptr->getKind() == IROperandOwner::Kind::Instruction ||
|
2018-10-29 01:03:19 +08:00
|
|
|
ptr->getKind() == IROperandOwner::Kind::OperationStmt;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an
|
|
|
|
/// IROperandOwner* to Operation*. This can't be done with a simple pointer to
|
|
|
|
/// pointer cast because the pointer adjustment depends on whether the Owner is
|
|
|
|
/// dynamically an Instruction or Statement, because of multiple inheritance.
|
|
|
|
Operation *
|
|
|
|
llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
|
|
|
|
mlir::IROperandOwner *>::doit(const mlir::IROperandOwner
|
|
|
|
*value) {
|
|
|
|
const Operation *op;
|
|
|
|
if (auto *ptr = dyn_cast<OperationStmt>(value))
|
|
|
|
op = ptr;
|
|
|
|
else
|
2018-11-16 10:31:15 +08:00
|
|
|
op = cast<Instruction>(value);
|
2018-10-29 01:03:19 +08:00
|
|
|
return const_cast<Operation *>(op);
|
|
|
|
}
|
2018-10-23 04:08:27 +08:00
|
|
|
|
2018-09-10 11:40:23 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-09-27 06:06:38 +08:00
|
|
|
// OpState trait class.
|
2018-09-10 11:40:23 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-22 10:49:31 +08:00
|
|
|
// The fallback for the parser is to reject the short form.
|
|
|
|
bool OpState::parse(OpAsmParser *parser, OperationState *result) {
|
|
|
|
return parser->emitError(parser->getNameLoc(), "has no concise form");
|
|
|
|
}
|
|
|
|
|
|
|
|
// The fallback for the printer is to print it the longhand form.
|
|
|
|
void OpState::print(OpAsmPrinter *p) const {
|
|
|
|
p->printDefaultOp(getOperation());
|
|
|
|
}
|
|
|
|
|
2018-09-10 11:40:23 +08:00
|
|
|
/// Emit an error about fatal conditions with this operation, reporting up to
|
|
|
|
/// any diagnostic handlers that may be listening. NOTE: This may terminate
|
|
|
|
/// the containing application, only use when the IR is in an inconsistent
|
|
|
|
/// state.
|
2018-09-27 06:06:38 +08:00
|
|
|
void OpState::emitError(const Twine &message) const {
|
2018-09-10 11:40:23 +08:00
|
|
|
getOperation()->emitError(message);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
|
|
|
/// convenient for verifiers.
|
2018-09-27 06:06:38 +08:00
|
|
|
bool OpState::emitOpError(const Twine &message) const {
|
2018-09-10 11:40:23 +08:00
|
|
|
return getOperation()->emitOpError(message);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Emit a warning about this operation, reporting up to any diagnostic
|
|
|
|
/// handlers that may be listening.
|
2018-09-27 06:06:38 +08:00
|
|
|
void OpState::emitWarning(const Twine &message) const {
|
2018-09-10 11:40:23 +08:00
|
|
|
getOperation()->emitWarning(message);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Emit a note about this operation, reporting up to any diagnostic
|
|
|
|
/// handlers that may be listening.
|
2018-09-27 06:06:38 +08:00
|
|
|
void OpState::emitNote(const Twine &message) const {
|
2018-09-10 11:40:23 +08:00
|
|
|
getOperation()->emitNote(message);
|
|
|
|
}
|
2018-09-27 01:07:16 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Op Trait implementations
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-09-27 12:18:42 +08:00
|
|
|
bool OpTrait::impl::verifyZeroOperands(const Operation *op) {
|
|
|
|
if (op->getNumOperands() != 0)
|
|
|
|
return op->emitOpError("requires zero operands");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyOneOperand(const Operation *op) {
|
|
|
|
if (op->getNumOperands() != 1)
|
|
|
|
return op->emitOpError("requires a single operand");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) {
|
2018-10-10 06:04:27 +08:00
|
|
|
if (op->getNumOperands() != numOperands) {
|
|
|
|
return op->emitOpError("expected " + Twine(numOperands) +
|
|
|
|
" operands, but found " +
|
|
|
|
Twine(op->getNumOperands()));
|
|
|
|
}
|
2018-09-27 12:18:42 +08:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op,
|
|
|
|
unsigned numOperands) {
|
|
|
|
if (op->getNumOperands() < numOperands)
|
|
|
|
return op->emitOpError("expected " + Twine(numOperands) +
|
|
|
|
" or more operands");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-11-07 07:37:39 +08:00
|
|
|
/// If this is a vector type, or a tensor type, return the scalar element type
|
|
|
|
/// that it is built around, otherwise return the type unmodified.
|
|
|
|
static Type getTensorOrVectorElementType(Type type) {
|
|
|
|
if (auto vec = type.dyn_cast<VectorType>())
|
|
|
|
return vec.getElementType();
|
|
|
|
|
|
|
|
// Look through tensor<vector<...>> to find the underlying element type.
|
|
|
|
if (auto tensor = type.dyn_cast<TensorType>())
|
|
|
|
return getTensorOrVectorElementType(tensor.getElementType());
|
|
|
|
return type;
|
|
|
|
}
|
|
|
|
|
Enable arithmetics for index types.
Arithmetic and comparison instructions are necessary to implement, e.g.,
control flow when lowering MLFunctions to CFGFunctions. (While it is possible
to replace some of the arithmetics by affine_apply instructions for loop
bounds, it is still necessary for loop bounds checking, steps, if-conditions,
non-trivial memref subscripts, etc.) Furthermore, working with indirect
accesses in, e.g., lookup tables for large embeddings, may require operating on
tensors of indexes. For example, the equivalents to C code "LUT[Index[i]]" or
"ResultIndex[i] = i + j" where i, j are loop induction variables require the
arithmetics on indices as well as the possibility to operate on tensors
thereof. Allow arithmetic and comparison operations to apply to index types by
declaring them integer-like. Allow tensors whose element type is index for
indirection purposes.
The absence of vectors with "index" element type is explicitly tested, but the
only justification for this restriction in the CL introducing the test is
"because we don't need them". Do NOT enable vectors of index types, although
it makes vector and tensor types inconsistent with respect to allowed element
types.
PiperOrigin-RevId: 220614055
2018-11-08 20:04:32 +08:00
|
|
|
// Checks if the given type is an integer or an index type. Following LLVM's
|
|
|
|
// convention, returns true if the check fails and false otherwise.
|
|
|
|
static inline bool checkIntegerLikeType(Type type) {
|
|
|
|
return !(type.isa<IntegerType>() || type.isa<IndexType>());
|
|
|
|
}
|
|
|
|
|
2018-11-07 07:37:39 +08:00
|
|
|
bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) {
|
|
|
|
for (auto *operand : op->getOperands()) {
|
Enable arithmetics for index types.
Arithmetic and comparison instructions are necessary to implement, e.g.,
control flow when lowering MLFunctions to CFGFunctions. (While it is possible
to replace some of the arithmetics by affine_apply instructions for loop
bounds, it is still necessary for loop bounds checking, steps, if-conditions,
non-trivial memref subscripts, etc.) Furthermore, working with indirect
accesses in, e.g., lookup tables for large embeddings, may require operating on
tensors of indexes. For example, the equivalents to C code "LUT[Index[i]]" or
"ResultIndex[i] = i + j" where i, j are loop induction variables require the
arithmetics on indices as well as the possibility to operate on tensors
thereof. Allow arithmetic and comparison operations to apply to index types by
declaring them integer-like. Allow tensors whose element type is index for
indirection purposes.
The absence of vectors with "index" element type is explicitly tested, but the
only justification for this restriction in the CL introducing the test is
"because we don't need them". Do NOT enable vectors of index types, although
it makes vector and tensor types inconsistent with respect to allowed element
types.
PiperOrigin-RevId: 220614055
2018-11-08 20:04:32 +08:00
|
|
|
auto type = getTensorOrVectorElementType(operand->getType());
|
|
|
|
if (checkIntegerLikeType(type))
|
|
|
|
return op->emitOpError("requires an integer or index type");
|
2018-11-07 07:37:39 +08:00
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifySameTypeOperands(const Operation *op) {
|
|
|
|
// Zero or one operand always have the "same" type.
|
|
|
|
unsigned nOperands = op->getNumOperands();
|
|
|
|
if (nOperands < 2)
|
|
|
|
return false;
|
|
|
|
|
|
|
|
auto type = op->getOperand(0)->getType();
|
|
|
|
for (unsigned i = 1; i < nOperands; ++i) {
|
|
|
|
if (op->getOperand(i)->getType() != type)
|
|
|
|
return op->emitOpError("requires all operands to have the same type");
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-09-27 12:18:42 +08:00
|
|
|
bool OpTrait::impl::verifyZeroResult(const Operation *op) {
|
|
|
|
if (op->getNumResults() != 0)
|
|
|
|
return op->emitOpError("requires zero results");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyOneResult(const Operation *op) {
|
|
|
|
if (op->getNumResults() != 1)
|
|
|
|
return op->emitOpError("requires one result");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) {
|
|
|
|
if (op->getNumResults() != numOperands)
|
|
|
|
return op->emitOpError("expected " + Twine(numOperands) + " results");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
|
|
|
|
unsigned numOperands) {
|
|
|
|
if (op->getNumResults() < numOperands)
|
|
|
|
return op->emitOpError("expected " + Twine(numOperands) +
|
|
|
|
" or more results");
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-09-27 01:07:16 +08:00
|
|
|
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
2018-10-31 05:59:22 +08:00
|
|
|
auto type = op->getResult(0)->getType();
|
2018-09-27 01:07:16 +08:00
|
|
|
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
|
|
|
if (op->getResult(i)->getType() != type)
|
|
|
|
return op->emitOpError(
|
|
|
|
"requires the same type for all operands and results");
|
|
|
|
}
|
|
|
|
for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
|
|
|
|
if (op->getOperand(i)->getType() != type)
|
|
|
|
return op->emitOpError(
|
|
|
|
"requires the same type for all operands and results");
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-11-16 01:56:06 +08:00
|
|
|
/// Check that `operands` line up with the arguments of `destBB`: the counts
/// must match and each operand's type must equal the type of the argument at
/// the same position. Errors are reported on `op`. Returns true on failure
/// and false on success.
static bool verifyBBArguments(
    llvm::iterator_range<Operation::const_operand_iterator> operands,
    const BasicBlock *destBB, const Operation *op) {
  unsigned numOperands = std::distance(operands.begin(), operands.end());
  if (numOperands != destBB->getNumArguments()) {
    op->emitError("branch has " + Twine(numOperands) +
                  " operands, but target block has " +
                  Twine(destBB->getNumArguments()));
    return true;
  }

  // Walk operands and block arguments in lock step.
  unsigned argNo = 0;
  for (auto it = operands.begin(); argNo != numOperands; ++argNo, ++it) {
    if ((*it)->getType() == destBB->getArgument(argNo)->getType())
      continue;
    op->emitError("type mismatch in bb argument #" + Twine(argNo));
    return true;
  }

  return false;
}
|
|
|
|
|
|
|
|
static bool verifyTerminatorSuccessors(const Operation *op) {
|
|
|
|
// Verify that the operands lines up with the BB arguments in the successor.
|
|
|
|
const Function *fn = op->getOperationFunction();
|
|
|
|
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
|
|
|
|
auto *succ = op->getSuccessor(i);
|
|
|
|
if (succ->getFunction() != fn) {
|
|
|
|
op->emitError("reference to block defined in another function");
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
if (verifyBBArguments(op->getSuccessorOperands(i), succ, op))
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-11-14 01:49:27 +08:00
|
|
|
bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
|
|
|
|
// Verify that the operation is at the end of the respective parent block.
|
|
|
|
if (auto *stmt = dyn_cast<OperationStmt>(op)) {
|
|
|
|
StmtBlock *block = stmt->getBlock();
|
|
|
|
if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
|
|
|
|
return op->emitOpError("must be the last statement in the ML function");
|
|
|
|
} else {
|
2018-11-16 10:31:15 +08:00
|
|
|
const Instruction *inst = cast<Instruction>(op);
|
2018-11-14 01:49:27 +08:00
|
|
|
const BasicBlock *block = inst->getBlock();
|
2018-11-16 07:20:23 +08:00
|
|
|
if (!block || &block->back() != inst)
|
2018-11-14 01:49:27 +08:00
|
|
|
return op->emitOpError(
|
|
|
|
"must be the last instruction in the parent basic block.");
|
|
|
|
}
|
|
|
|
|
2018-11-16 01:56:06 +08:00
|
|
|
// Verify the state of the successor blocks.
|
|
|
|
if (op->getNumSuccessors() != 0 && verifyTerminatorSuccessors(op))
|
|
|
|
return true;
|
2018-11-14 01:49:27 +08:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-09-27 01:07:16 +08:00
|
|
|
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
|
|
|
for (auto *result : op->getResults()) {
|
2018-10-31 05:59:22 +08:00
|
|
|
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
|
2018-09-27 01:07:16 +08:00
|
|
|
return op->emitOpError("requires a floating point type");
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
|
|
|
|
for (auto *result : op->getResults()) {
|
Enable arithmetics for index types.
Arithmetic and comparison instructions are necessary to implement, e.g.,
control flow when lowering MLFunctions to CFGFunctions. (While it is possible
to replace some of the arithmetics by affine_apply instructions for loop
bounds, it is still necessary for loop bounds checking, steps, if-conditions,
non-trivial memref subscripts, etc.) Furthermore, working with indirect
accesses in, e.g., lookup tables for large embeddings, may require operating on
tensors of indexes. For example, the equivalents to C code "LUT[Index[i]]" or
"ResultIndex[i] = i + j" where i, j are loop induction variables require the
arithmetics on indices as well as the possibility to operate on tensors
thereof. Allow arithmetic and comparison operations to apply to index types by
declaring them integer-like. Allow tensors whose element type is index for
indirection purposes.
The absence of vectors with "index" element type is explicitly tested, but the
only justification for this restriction in the CL introducing the test is
"because we don't need them". Do NOT enable vectors of index types, although
it makes vector and tensor types inconsistent with respect to allowed element
types.
PiperOrigin-RevId: 220614055
2018-11-08 20:04:32 +08:00
|
|
|
auto type = getTensorOrVectorElementType(result->getType());
|
|
|
|
if (checkIntegerLikeType(type))
|
|
|
|
return op->emitOpError("requires an integer or index type");
|
2018-09-27 01:07:16 +08:00
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// BinaryOp implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// These functions are out-of-line implementations of the methods in BinaryOp,
|
|
|
|
// which avoids them being template instantiated/duplicated.
|
|
|
|
|
|
|
|
/// Populates `result` for a binary operation: `lhs` and `rhs` become the two
/// operands, and a single result type equal to the type of `lhs` is appended.
///
/// Both operands are asserted to have the same type. `builder` is not used by
/// this implementation.
void impl::buildBinaryOp(Builder *builder, OperationState *result,
                         SSAValue *lhs, SSAValue *rhs) {
  assert(lhs->getType() == rhs->getType());

  // Operands first, then the (shared) operand type as the sole result type.
  SSAValue *const binaryOperands[] = {lhs, rhs};
  result->addOperands(binaryOperands);
  result->types.push_back(lhs->getType());
}
|
|
|
|
|
|
|
|
/// Parses the textual form of a binary operation into `result`.
///
/// The pieces are consumed in this order, stopping at the first parser call
/// that reports failure (a true return):
///   1. an operand list of exactly two operands,
///   2. an optional attribute dictionary, stored in `result->attributes`,
///   3. a colon followed by a type,
///   4. resolution of both operands against that type into
///      `result->operands`,
///   5. the same type appended to `result->types`.
///
/// Returns true if any step failed and false otherwise.
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
  SmallVector<OpAsmParser::OperandType, 2> operandInfos;
  Type operandType;

  if (parser->parseOperandList(operandInfos, 2))
    return true;
  if (parser->parseOptionalAttributeDict(result->attributes))
    return true;
  if (parser->parseColonType(operandType))
    return true;
  if (parser->resolveOperands(operandInfos, operandType, result->operands))
    return true;
  if (parser->addTypeToList(operandType, result->types))
    return true;
  return false;
}
|
|
|
|
|
|
|
|
void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
|
2018-10-23 00:00:03 +08:00
|
|
|
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
|
2018-09-27 01:07:16 +08:00
|
|
|
<< *op->getOperand(1);
|
|
|
|
p->printOptionalAttrDict(op->getAttrs());
|
2018-10-31 05:59:22 +08:00
|
|
|
*p << " : " << op->getResult(0)->getType();
|
2018-09-27 01:07:16 +08:00
|
|
|
}
|
2018-10-23 00:00:03 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CastOp implementation
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
void impl::buildCastOp(Builder *builder, OperationState *result,
|
2018-10-31 05:59:22 +08:00
|
|
|
SSAValue *source, Type destType) {
|
2018-10-23 00:00:03 +08:00
|
|
|
result->addOperands(source);
|
|
|
|
result->addTypes(destType);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Parses the textual form of a cast operation into `result`.
///
/// The pieces are consumed in this order, stopping at the first parser call
/// that reports failure (a true return):
///   1. the source operand,
///   2. a colon followed by the source type,
///   3. resolution of the operand against the source type into
///      `result->operands`,
///   4. the keyword "to" followed by the destination type,
///   5. the destination type appended to `result->types`.
///
/// Returns true if any step failed and false otherwise.
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
  OpAsmParser::OperandType sourceOperand;
  Type sourceType, resultType;

  if (parser->parseOperand(sourceOperand))
    return true;
  if (parser->parseColonType(sourceType))
    return true;
  if (parser->resolveOperand(sourceOperand, sourceType, result->operands))
    return true;
  if (parser->parseKeywordType("to", resultType))
    return true;
  if (parser->addTypeToList(resultType, result->types))
    return true;
  return false;
}
|
|
|
|
|
|
|
|
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
|
|
|
|
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
|
2018-10-31 05:59:22 +08:00
|
|
|
<< op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
|
2018-10-23 00:00:03 +08:00
|
|
|
}
|