2018-06-24 07:03:42 +08:00
|
|
|
//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The MLIR Authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// This file implements the MLIR AsmPrinter class, which is used to implement
|
|
|
|
// the various print() methods on the core IR objects.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-06-30 09:09:29 +08:00
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
2018-07-05 11:45:39 +08:00
|
|
|
#include "mlir/IR/Attributes.h"
|
2018-10-11 05:23:30 +08:00
|
|
|
#include "mlir/IR/BuiltinOps.h"
|
2018-06-24 07:03:42 +08:00
|
|
|
#include "mlir/IR/CFGFunction.h"
|
2018-08-08 05:24:38 +08:00
|
|
|
#include "mlir/IR/IntegerSet.h"
|
2018-06-29 08:02:32 +08:00
|
|
|
#include "mlir/IR/MLFunction.h"
|
2018-06-24 07:03:42 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2018-07-25 07:07:22 +08:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
2018-07-14 04:03:13 +08:00
|
|
|
#include "mlir/IR/Statements.h"
|
2018-07-27 09:09:20 +08:00
|
|
|
#include "mlir/IR/StmtVisitor.h"
|
2018-06-24 07:03:42 +08:00
|
|
|
#include "mlir/IR/Types.h"
|
|
|
|
#include "mlir/Support/STLExtras.h"
|
2018-08-16 00:09:54 +08:00
|
|
|
#include "llvm/ADT/APFloat.h"
|
2018-06-24 07:03:42 +08:00
|
|
|
#include "llvm/ADT/DenseMap.h"
|
2018-08-02 01:43:18 +08:00
|
|
|
#include "llvm/ADT/SmallString.h"
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
|
|
|
#include "llvm/ADT/StringSet.h"
|
2018-10-11 05:23:30 +08:00
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
using namespace mlir;
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
/// Writes the spelling of this identifier to `os`.
void Identifier::print(raw_ostream &os) const {
  auto spelling = str();
  os << spelling;
}
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void Identifier::dump() const { print(llvm::errs()); }
|
2018-07-05 11:45:39 +08:00
|
|
|
|
2018-10-10 13:08:52 +08:00
|
|
|
/// Writes the operation name, as returned by getStringRef(), to `os`.
void OperationName::print(raw_ostream &os) const {
  auto name = getStringRef();
  os << name;
}
|
|
|
|
|
|
|
|
void OperationName::dump() const { print(llvm::errs()); }
|
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
/// Out-of-line definition of the OpAsmPrinter destructor; it does no work.
OpAsmPrinter::~OpAsmPrinter() = default;
|
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-07-21 00:35:47 +08:00
|
|
|
// ModuleState
|
2018-07-18 07:56:54 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2018-07-19 01:16:05 +08:00
|
|
|
class ModuleState {
|
|
|
|
public:
|
2018-10-22 10:49:31 +08:00
|
|
|
/// This is the current context if it is knowable, otherwise this is null.
|
|
|
|
MLIRContext *const context;
|
2018-07-18 07:56:54 +08:00
|
|
|
|
2018-10-22 10:49:31 +08:00
|
|
|
explicit ModuleState(MLIRContext *context) : context(context) {}
|
2018-07-18 07:56:54 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
// Initializes module state, populating affine map state.
|
|
|
|
void initialize(const Module *module);
|
2018-07-18 07:56:54 +08:00
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
int getAffineMapId(AffineMap affineMap) const {
|
2018-07-18 07:56:54 +08:00
|
|
|
auto it = affineMapIds.find(affineMap);
|
|
|
|
if (it == affineMapIds.end()) {
|
|
|
|
return -1;
|
|
|
|
}
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
|
2018-07-21 00:35:47 +08:00
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
int getIntegerSetId(IntegerSet integerSet) const {
|
2018-08-08 05:24:38 +08:00
|
|
|
auto it = integerSetIds.find(integerSet);
|
|
|
|
if (it == integerSetIds.end()) {
|
|
|
|
return -1;
|
|
|
|
}
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; }
|
2018-08-08 05:24:38 +08:00
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
private:
|
2018-10-10 07:39:24 +08:00
|
|
|
void recordAffineMapReference(AffineMap affineMap) {
|
2018-07-21 00:35:47 +08:00
|
|
|
if (affineMapIds.count(affineMap) == 0) {
|
2018-07-25 00:48:31 +08:00
|
|
|
affineMapIds[affineMap] = affineMapsById.size();
|
|
|
|
affineMapsById.push_back(affineMap);
|
2018-07-21 00:35:47 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
void recordIntegerSetReference(IntegerSet integerSet) {
|
2018-08-08 05:24:38 +08:00
|
|
|
if (integerSetIds.count(integerSet) == 0) {
|
|
|
|
integerSetIds[integerSet] = integerSetsById.size();
|
|
|
|
integerSetsById.push_back(integerSet);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
// Return true if this map could be printed using the shorthand form.
|
2018-10-10 07:39:24 +08:00
|
|
|
static bool hasShorthandForm(AffineMap boundMap) {
|
|
|
|
if (boundMap.isSingleConstant())
|
2018-08-25 14:38:14 +08:00
|
|
|
return true;
|
|
|
|
|
|
|
|
// Check if the affine map is single dim id or single symbol identity -
|
|
|
|
// (i)->(i) or ()[s]->(i)
|
2018-10-10 07:39:24 +08:00
|
|
|
return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 &&
|
|
|
|
(boundMap.getResult(0).isa<AffineDimExpr>() ||
|
|
|
|
boundMap.getResult(0).isa<AffineSymbolExpr>());
|
2018-08-25 14:38:14 +08:00
|
|
|
}
|
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
// Visit functions.
|
|
|
|
void visitFunction(const Function *fn);
|
|
|
|
void visitExtFunction(const ExtFunction *fn);
|
|
|
|
void visitCFGFunction(const CFGFunction *fn);
|
|
|
|
void visitMLFunction(const MLFunction *fn);
|
2018-08-08 05:24:38 +08:00
|
|
|
void visitStatement(const Statement *stmt);
|
|
|
|
void visitForStmt(const ForStmt *forStmt);
|
|
|
|
void visitIfStmt(const IfStmt *ifStmt);
|
|
|
|
void visitOperationStmt(const OperationStmt *opStmt);
|
2018-07-18 07:56:54 +08:00
|
|
|
void visitType(const Type *type);
|
2018-07-19 07:29:21 +08:00
|
|
|
void visitAttribute(const Attribute *attr);
|
|
|
|
void visitOperation(const Operation *op);
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
DenseMap<AffineMap, int> affineMapIds;
|
|
|
|
std::vector<AffineMap> affineMapsById;
|
2018-08-08 05:24:38 +08:00
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
DenseMap<IntegerSet, int> integerSetIds;
|
|
|
|
std::vector<IntegerSet> integerSetsById;
|
2018-07-18 07:56:54 +08:00
|
|
|
};
|
2018-07-24 02:44:40 +08:00
|
|
|
} // end anonymous namespace
|
2018-07-18 07:56:54 +08:00
|
|
|
|
|
|
|
// TODO Support visiting other types/instructions when implemented.
|
|
|
|
void ModuleState::visitType(const Type *type) {
|
2018-07-29 00:36:25 +08:00
|
|
|
if (auto *funcType = dyn_cast<FunctionType>(type)) {
|
2018-07-18 07:56:54 +08:00
|
|
|
// Visit input and result types for functions.
|
2018-07-29 00:36:25 +08:00
|
|
|
for (auto *input : funcType->getInputs())
|
2018-07-18 07:56:54 +08:00
|
|
|
visitType(input);
|
2018-07-29 00:36:25 +08:00
|
|
|
for (auto *result : funcType->getResults())
|
2018-07-18 07:56:54 +08:00
|
|
|
visitType(result);
|
2018-07-29 00:36:25 +08:00
|
|
|
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
|
2018-07-18 07:56:54 +08:00
|
|
|
// Visit affine maps in memref type.
|
2018-10-10 07:39:24 +08:00
|
|
|
for (auto map : memref->getAffineMaps()) {
|
2018-07-18 07:56:54 +08:00
|
|
|
recordAffineMapReference(map);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-19 07:29:21 +08:00
|
|
|
void ModuleState::visitAttribute(const Attribute *attr) {
|
2018-07-29 00:36:25 +08:00
|
|
|
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
|
|
|
|
recordAffineMapReference(mapAttr->getValue());
|
2018-08-08 05:24:38 +08:00
|
|
|
} else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
|
|
|
for (auto elt : arrayAttr->getValue()) {
|
2018-07-19 07:29:21 +08:00
|
|
|
visitAttribute(elt);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitOperation(const Operation *op) {
|
2018-08-01 09:32:59 +08:00
|
|
|
// Visit all the types used in the operation.
|
|
|
|
for (auto *operand : op->getOperands())
|
|
|
|
visitType(operand->getType());
|
|
|
|
for (auto *result : op->getResults())
|
|
|
|
visitType(result->getType());
|
|
|
|
|
|
|
|
// Visit each of the attributes.
|
|
|
|
for (auto elt : op->getAttrs())
|
2018-07-19 07:29:21 +08:00
|
|
|
visitAttribute(elt.second);
|
|
|
|
}
|
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
void ModuleState::visitExtFunction(const ExtFunction *fn) {
|
|
|
|
visitType(fn->getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitCFGFunction(const CFGFunction *fn) {
|
|
|
|
visitType(fn->getType());
|
2018-07-19 07:29:21 +08:00
|
|
|
for (auto &block : *fn) {
|
|
|
|
for (auto &op : block.getOperations()) {
|
|
|
|
visitOperation(&op);
|
|
|
|
}
|
|
|
|
}
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
|
2018-08-08 05:24:38 +08:00
|
|
|
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
|
2018-08-29 06:26:20 +08:00
|
|
|
recordIntegerSetReference(ifStmt->getIntegerSet());
|
2018-08-09 02:14:57 +08:00
|
|
|
for (auto &childStmt : *ifStmt->getThen())
|
2018-08-08 05:24:38 +08:00
|
|
|
visitStatement(&childStmt);
|
2018-08-09 02:14:57 +08:00
|
|
|
if (ifStmt->hasElse())
|
|
|
|
for (auto &childStmt : *ifStmt->getElse())
|
2018-08-08 05:24:38 +08:00
|
|
|
visitStatement(&childStmt);
|
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitForStmt(const ForStmt *forStmt) {
|
2018-10-10 07:39:24 +08:00
|
|
|
AffineMap lbMap = forStmt->getLowerBoundMap();
|
2018-08-25 14:38:14 +08:00
|
|
|
if (!hasShorthandForm(lbMap))
|
|
|
|
recordAffineMapReference(lbMap);
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
AffineMap ubMap = forStmt->getUpperBoundMap();
|
2018-08-25 14:38:14 +08:00
|
|
|
if (!hasShorthandForm(ubMap))
|
|
|
|
recordAffineMapReference(ubMap);
|
|
|
|
|
2018-08-08 05:24:38 +08:00
|
|
|
for (auto &childStmt : *forStmt)
|
|
|
|
visitStatement(&childStmt);
|
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitOperationStmt(const OperationStmt *opStmt) {
|
2018-08-16 00:09:54 +08:00
|
|
|
for (auto attr : opStmt->getAttrs())
|
|
|
|
visitAttribute(attr.second);
|
2018-08-08 05:24:38 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitStatement(const Statement *stmt) {
|
|
|
|
switch (stmt->getKind()) {
|
|
|
|
case Statement::Kind::If:
|
|
|
|
return visitIfStmt(cast<IfStmt>(stmt));
|
|
|
|
case Statement::Kind::For:
|
|
|
|
return visitForStmt(cast<ForStmt>(stmt));
|
|
|
|
case Statement::Kind::Operation:
|
|
|
|
return visitOperationStmt(cast<OperationStmt>(stmt));
|
|
|
|
default:
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
void ModuleState::visitMLFunction(const MLFunction *fn) {
|
|
|
|
visitType(fn->getType());
|
2018-08-08 05:24:38 +08:00
|
|
|
for (auto &stmt : *fn) {
|
|
|
|
ModuleState::visitStatement(&stmt);
|
|
|
|
}
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void ModuleState::visitFunction(const Function *fn) {
|
|
|
|
switch (fn->getKind()) {
|
2018-07-19 01:16:05 +08:00
|
|
|
case Function::Kind::ExtFunc:
|
|
|
|
return visitExtFunction(cast<ExtFunction>(fn));
|
|
|
|
case Function::Kind::CFGFunc:
|
|
|
|
return visitCFGFunction(cast<CFGFunction>(fn));
|
|
|
|
case Function::Kind::MLFunc:
|
|
|
|
return visitMLFunction(cast<MLFunction>(fn));
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-08-08 05:24:38 +08:00
|
|
|
// Initializes module state, populating affine map and integer set state.
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModuleState::initialize(const Module *module) {
|
2018-07-26 05:08:16 +08:00
|
|
|
for (auto &fn : *module) {
|
|
|
|
visitFunction(&fn);
|
2018-07-21 00:35:47 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ModulePrinter
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
class ModulePrinter {
|
|
|
|
public:
|
|
|
|
ModulePrinter(raw_ostream &os, ModuleState &state) : os(os), state(state) {}
|
|
|
|
explicit ModulePrinter(const ModulePrinter &printer)
|
|
|
|
: os(printer.os), state(printer.state) {}
|
|
|
|
|
|
|
|
template <typename Container, typename UnaryFunctor>
|
|
|
|
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
|
|
|
|
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
|
|
|
|
}
|
|
|
|
|
|
|
|
void print(const Module *module);
|
2018-08-22 08:55:22 +08:00
|
|
|
void printFunctionReference(const Function *func);
|
2018-07-25 07:07:22 +08:00
|
|
|
void printAttribute(const Attribute *attr);
|
|
|
|
void printType(const Type *type);
|
2018-07-21 00:35:47 +08:00
|
|
|
void print(const Function *fn);
|
|
|
|
void print(const ExtFunction *fn);
|
|
|
|
void print(const CFGFunction *fn);
|
|
|
|
void print(const MLFunction *fn);
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
void printAffineMap(AffineMap map);
|
2018-10-09 04:47:18 +08:00
|
|
|
void printAffineExpr(AffineExpr expr);
|
|
|
|
void printAffineConstraint(AffineExpr expr, bool isEq);
|
2018-10-11 00:45:59 +08:00
|
|
|
void printIntegerSet(IntegerSet set);
|
2018-07-21 00:35:47 +08:00
|
|
|
|
|
|
|
protected:
|
|
|
|
raw_ostream &os;
|
|
|
|
ModuleState &state;
|
|
|
|
|
|
|
|
void printFunctionSignature(const Function *fn);
|
2018-09-19 07:36:26 +08:00
|
|
|
void printFunctionAttributes(const Function *fn);
|
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
|
|
|
ArrayRef<const char *> elidedAttrs = {});
|
2018-08-07 02:54:39 +08:00
|
|
|
void printFunctionResultType(const FunctionType *type);
|
2018-07-21 00:35:47 +08:00
|
|
|
void printAffineMapId(int affineMapId) const;
|
2018-10-10 07:39:24 +08:00
|
|
|
void printAffineMapReference(AffineMap affineMap);
|
2018-08-08 05:24:38 +08:00
|
|
|
void printIntegerSetId(int integerSetId) const;
|
2018-10-11 00:45:59 +08:00
|
|
|
void printIntegerSetReference(IntegerSet integerSet);
|
2018-10-19 04:54:44 +08:00
|
|
|
void printDenseElementsAttr(const DenseElementsAttr *attr);
|
2018-07-21 00:35:47 +08:00
|
|
|
|
2018-08-01 07:21:36 +08:00
|
|
|
/// This enum is used to represent the binding stength of the enclosing
|
2018-10-10 01:59:27 +08:00
|
|
|
/// context that an AffineExprStorage is being printed in, so we can
|
2018-10-09 04:47:18 +08:00
|
|
|
/// intelligently produce parens.
|
2018-08-01 07:21:36 +08:00
|
|
|
enum class BindingStrength {
|
|
|
|
Weak, // + and -
|
|
|
|
Strong, // All other binary operators.
|
|
|
|
};
|
2018-10-09 04:47:18 +08:00
|
|
|
void printAffineExprInternal(AffineExpr expr,
|
2018-08-01 07:21:36 +08:00
|
|
|
BindingStrength enclosingTightness);
|
2018-07-21 00:35:47 +08:00
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
// Prints function with initialized module state.
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::print(const Function *fn) {
|
2018-07-18 07:56:54 +08:00
|
|
|
switch (fn->getKind()) {
|
2018-07-19 01:16:05 +08:00
|
|
|
case Function::Kind::ExtFunc:
|
|
|
|
return print(cast<ExtFunction>(fn));
|
|
|
|
case Function::Kind::CFGFunc:
|
|
|
|
return print(cast<CFGFunction>(fn));
|
|
|
|
case Function::Kind::MLFunc:
|
|
|
|
return print(cast<MLFunction>(fn));
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Prints affine map identifier.
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::printAffineMapId(int affineMapId) const {
|
2018-07-18 07:56:54 +08:00
|
|
|
os << "#map" << affineMapId;
|
|
|
|
}
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
/// Prints `affineMap` as a `#mapN` reference when the module state recorded it
/// (it is then printed at the top of the module), and inline otherwise.
void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
  int mapId = state.getAffineMapId(affineMap);
  if (mapId < 0) {
    affineMap.print(os);
    return;
  }
  printAffineMapId(mapId);
}
|
|
|
|
|
2018-08-08 05:24:38 +08:00
|
|
|
// Prints integer set identifier.
|
|
|
|
void ModulePrinter::printIntegerSetId(int integerSetId) const {
|
|
|
|
os << "@@set" << integerSetId;
|
|
|
|
}
|
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
/// Prints `integerSet` as a `@@setN` reference when the module state recorded
/// it (it is then printed at the top of the module), and inline otherwise.
void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
  int setId = state.getIntegerSetId(integerSet);
  if (setId >= 0)
    printIntegerSetId(setId);
  else
    integerSet.print(os);
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::print(const Module *module) {
|
2018-07-25 00:48:31 +08:00
|
|
|
for (const auto &map : state.getAffineMapIds()) {
|
|
|
|
printAffineMapId(state.getAffineMapId(map));
|
2018-07-18 07:56:54 +08:00
|
|
|
os << " = ";
|
2018-10-10 07:39:24 +08:00
|
|
|
map.print(os);
|
2018-07-18 07:56:54 +08:00
|
|
|
os << '\n';
|
|
|
|
}
|
2018-08-08 05:24:38 +08:00
|
|
|
for (const auto &set : state.getIntegerSetIds()) {
|
|
|
|
printIntegerSetId(state.getIntegerSetId(set));
|
|
|
|
os << " = ";
|
2018-10-11 00:45:59 +08:00
|
|
|
set.print(os);
|
2018-08-08 05:24:38 +08:00
|
|
|
os << '\n';
|
|
|
|
}
|
2018-07-26 05:08:16 +08:00
|
|
|
for (auto const &fn : *module)
|
|
|
|
print(&fn);
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
|
2018-08-16 00:09:54 +08:00
|
|
|
/// Print a floating point value in a way that the parser will be able to
|
|
|
|
/// round-trip losslessly.
|
2018-10-21 09:31:49 +08:00
|
|
|
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
|
2018-08-16 00:09:54 +08:00
|
|
|
// We would like to output the FP constant value in exponential notation,
|
|
|
|
// but we cannot do this if doing so will lose precision. Check here to
|
|
|
|
// make sure that we only output it in exponential format if we can parse
|
|
|
|
// the value back and get the same value.
|
|
|
|
bool isInf = apValue.isInfinity();
|
|
|
|
bool isNaN = apValue.isNaN();
|
|
|
|
if (!isInf && !isNaN) {
|
|
|
|
SmallString<128> strValue;
|
|
|
|
apValue.toString(strValue, 6, 0, false);
|
|
|
|
|
|
|
|
// Check to make sure that the stringized number is not some string like
|
|
|
|
// "Inf" or NaN, that atof will accept, but the lexer will not. Check
|
|
|
|
// that the string matches the "[-+]?[0-9]" regex.
|
|
|
|
assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
|
|
|
|
((strValue[0] == '-' || strValue[0] == '+') &&
|
|
|
|
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
|
|
|
|
"[-+]?[0-9] regex does not match!");
|
|
|
|
// Reparse stringized version!
|
2018-10-21 09:31:49 +08:00
|
|
|
if (APFloat(APFloat::IEEEdouble(), strValue).bitwiseIsEqual(apValue)) {
|
2018-08-16 00:09:54 +08:00
|
|
|
os << strValue;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-21 09:31:49 +08:00
|
|
|
SmallVector<char, 16> str;
|
|
|
|
apValue.toString(str);
|
|
|
|
os << str;
|
2018-08-16 00:09:54 +08:00
|
|
|
}
|
|
|
|
|
2018-08-22 08:55:22 +08:00
|
|
|
/// Prints a symbolic reference to `func`: '@' followed by the function name.
void ModulePrinter::printFunctionReference(const Function *func) {
  os << '@';
  os << func->getName();
}
|
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
void ModulePrinter::printAttribute(const Attribute *attr) {
|
2018-07-19 07:29:21 +08:00
|
|
|
switch (attr->getKind()) {
|
|
|
|
case Attribute::Kind::Bool:
|
|
|
|
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
|
|
|
|
break;
|
|
|
|
case Attribute::Kind::Integer:
|
|
|
|
os << cast<IntegerAttr>(attr)->getValue();
|
|
|
|
break;
|
|
|
|
case Attribute::Kind::Float:
|
2018-08-16 00:09:54 +08:00
|
|
|
printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
|
2018-07-19 07:29:21 +08:00
|
|
|
break;
|
|
|
|
case Attribute::Kind::String:
|
2018-08-16 00:09:54 +08:00
|
|
|
os << '"';
|
|
|
|
printEscapedString(cast<StringAttr>(attr)->getValue(), os);
|
|
|
|
os << '"';
|
2018-07-19 07:29:21 +08:00
|
|
|
break;
|
2018-08-16 00:09:54 +08:00
|
|
|
case Attribute::Kind::Array:
|
2018-07-19 07:29:21 +08:00
|
|
|
os << '[';
|
2018-08-16 00:09:54 +08:00
|
|
|
interleaveComma(cast<ArrayAttr>(attr)->getValue(),
|
|
|
|
[&](Attribute *attr) { printAttribute(attr); });
|
2018-07-19 07:29:21 +08:00
|
|
|
os << ']';
|
|
|
|
break;
|
|
|
|
case Attribute::Kind::AffineMap:
|
|
|
|
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
|
2018-08-03 16:54:46 +08:00
|
|
|
break;
|
|
|
|
case Attribute::Kind::Type:
|
|
|
|
printType(cast<TypeAttr>(attr)->getValue());
|
2018-07-19 07:29:21 +08:00
|
|
|
break;
|
2018-08-20 12:17:22 +08:00
|
|
|
case Attribute::Kind::Function: {
|
|
|
|
auto *function = cast<FunctionAttr>(attr)->getValue();
|
|
|
|
if (!function) {
|
|
|
|
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
|
|
|
} else {
|
2018-08-22 08:55:22 +08:00
|
|
|
printFunctionReference(function);
|
|
|
|
os << " : ";
|
2018-08-20 12:17:22 +08:00
|
|
|
printType(function->getType());
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
}
|
2018-10-19 04:54:44 +08:00
|
|
|
case Attribute::Kind::DenseIntElements:
|
|
|
|
case Attribute::Kind::DenseFPElements: {
|
|
|
|
auto *eltsAttr = cast<DenseElementsAttr>(attr);
|
|
|
|
os << "dense<";
|
|
|
|
printType(eltsAttr->getType());
|
|
|
|
os << ", ";
|
|
|
|
printDenseElementsAttr(eltsAttr);
|
|
|
|
os << '>';
|
|
|
|
break;
|
|
|
|
}
|
2018-10-10 23:57:51 +08:00
|
|
|
case Attribute::Kind::SplatElements: {
|
|
|
|
auto *elementsAttr = cast<SplatElementsAttr>(attr);
|
|
|
|
os << "splat<";
|
|
|
|
printType(elementsAttr->getType());
|
|
|
|
os << ", ";
|
|
|
|
printAttribute(elementsAttr->getValue());
|
|
|
|
os << '>';
|
|
|
|
break;
|
|
|
|
}
|
Add support to constant sparse tensor / vector attribute
The SparseElementsAttr uses (COO) Coordinate List encoding to represents a
sparse tensor / vector. Specifically, the coordinates and values are stored as
two dense elements attributes. The first dense elements attribute is a 2-D
attribute with shape [N, ndims], which contains the indices of the elements
with nonzero values in the constant vector/tensor. The second elements
attribute is a 1-D attribute list with shape [N], which supplies the values for
each element in the first elements attribute. ndims is the rank of the
vector/tensor and N is the total nonzero elements.
The syntax is:
`sparse<` (tensor-type | vector-type)`, ` indices-attribute-list, values-attribute-list `>`
Example: a sparse tensor
sparse<vector<3x4xi32>, [[0, 0], [1, 2]], [1, 2]> represents the dense tensor
[[1, 0, 0, 0]
[0, 0, 2, 0]
[0, 0, 0, 0]]
PiperOrigin-RevId: 217764319
2018-10-19 05:02:20 +08:00
|
|
|
case Attribute::Kind::SparseElements: {
|
|
|
|
auto *elementsAttr = cast<SparseElementsAttr>(attr);
|
|
|
|
os << "sparse<";
|
|
|
|
printType(elementsAttr->getType());
|
|
|
|
os << ", ";
|
|
|
|
printDenseElementsAttr(elementsAttr->getIndices());
|
|
|
|
os << ", ";
|
|
|
|
printDenseElementsAttr(elementsAttr->getValues());
|
|
|
|
os << '>';
|
|
|
|
break;
|
|
|
|
}
|
2018-07-19 07:29:21 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-19 04:54:44 +08:00
|
|
|
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) {
|
|
|
|
auto *type = attr->getType();
|
|
|
|
auto shape = type->getShape();
|
|
|
|
auto rank = type->getRank();
|
|
|
|
|
|
|
|
SmallVector<Attribute *, 16> elements;
|
|
|
|
attr->getValues(elements);
|
|
|
|
|
|
|
|
// Special case for degenerate tensors.
|
|
|
|
if (elements.empty()) {
|
|
|
|
for (int i = 0; i < rank; ++i)
|
|
|
|
os << '[';
|
|
|
|
for (int i = 0; i < rank; ++i)
|
|
|
|
os << ']';
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// We use a mixed-radix counter to iterate through the shape. When we bump a
|
|
|
|
// non-least-significant digit, we emit a close bracket. When we next emit an
|
|
|
|
// element we re-open all closed brackets.
|
|
|
|
|
|
|
|
// The mixed-radix counter, with radices in 'shape'.
|
|
|
|
SmallVector<unsigned, 4> counter(rank, 0);
|
|
|
|
// The number of brackets that have been opened and not closed.
|
|
|
|
unsigned openBrackets = 0;
|
|
|
|
|
|
|
|
auto bumpCounter = [&]() {
|
|
|
|
// Bump the least significant digit.
|
|
|
|
++counter[rank - 1];
|
|
|
|
// Iterate backwards bubbling back the increment.
|
|
|
|
for (unsigned i = rank - 1; i > 0; --i)
|
|
|
|
if (counter[i] >= shape[i]) {
|
|
|
|
// Index 'i' is rolled over. Bump (i-1) and close a bracket.
|
|
|
|
counter[i] = 0;
|
|
|
|
++counter[i - 1];
|
|
|
|
--openBrackets;
|
|
|
|
os << ']';
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
for (unsigned idx = 0, e = elements.size(); idx != e; ++idx) {
|
|
|
|
if (idx != 0)
|
|
|
|
os << ", ";
|
|
|
|
while (openBrackets++ < rank)
|
|
|
|
os << '[';
|
|
|
|
openBrackets = rank;
|
|
|
|
printAttribute(elements[idx]);
|
|
|
|
bumpCounter();
|
|
|
|
}
|
|
|
|
while (openBrackets-- > 0)
|
|
|
|
os << ']';
|
|
|
|
}
|
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
void ModulePrinter::printType(const Type *type) {
|
2018-07-18 07:56:54 +08:00
|
|
|
switch (type->getKind()) {
|
2018-10-07 08:21:53 +08:00
|
|
|
case Type::Kind::Index:
|
|
|
|
os << "index";
|
2018-07-19 01:16:05 +08:00
|
|
|
return;
|
|
|
|
case Type::Kind::BF16:
|
|
|
|
os << "bf16";
|
|
|
|
return;
|
|
|
|
case Type::Kind::F16:
|
|
|
|
os << "f16";
|
|
|
|
return;
|
|
|
|
case Type::Kind::F32:
|
|
|
|
os << "f32";
|
|
|
|
return;
|
|
|
|
case Type::Kind::F64:
|
|
|
|
os << "f64";
|
|
|
|
return;
|
2018-07-28 02:07:12 +08:00
|
|
|
case Type::Kind::TFControl:
|
|
|
|
os << "tf_control";
|
|
|
|
return;
|
2018-09-20 01:28:46 +08:00
|
|
|
case Type::Kind::TFResource:
|
|
|
|
os << "tf_resource";
|
|
|
|
return;
|
2018-09-20 12:15:43 +08:00
|
|
|
case Type::Kind::TFVariant:
|
|
|
|
os << "tf_variant";
|
|
|
|
return;
|
2018-09-21 02:59:17 +08:00
|
|
|
case Type::Kind::TFComplex64:
|
|
|
|
os << "tf_complex64";
|
|
|
|
return;
|
|
|
|
case Type::Kind::TFComplex128:
|
|
|
|
os << "tf_complex128";
|
|
|
|
return;
|
2018-09-27 13:10:45 +08:00
|
|
|
case Type::Kind::TFF32REF:
|
|
|
|
os << "tf_f32ref";
|
|
|
|
return;
|
2018-08-02 03:55:27 +08:00
|
|
|
case Type::Kind::TFString:
|
|
|
|
os << "tf_string";
|
|
|
|
return;
|
2018-07-18 07:56:54 +08:00
|
|
|
|
|
|
|
case Type::Kind::Integer: {
|
|
|
|
auto *integer = cast<IntegerType>(type);
|
|
|
|
os << 'i' << integer->getWidth();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
case Type::Kind::Function: {
|
|
|
|
auto *func = cast<FunctionType>(type);
|
|
|
|
os << '(';
|
2018-07-26 03:55:50 +08:00
|
|
|
interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
|
2018-07-18 07:56:54 +08:00
|
|
|
os << ") -> ";
|
|
|
|
auto results = func->getResults();
|
|
|
|
if (results.size() == 1)
|
|
|
|
os << *results[0];
|
|
|
|
else {
|
|
|
|
os << '(';
|
2018-07-26 03:55:50 +08:00
|
|
|
interleaveComma(results, [&](Type *type) { printType(type); });
|
2018-07-18 07:56:54 +08:00
|
|
|
os << ')';
|
|
|
|
}
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
case Type::Kind::Vector: {
|
|
|
|
auto *v = cast<VectorType>(type);
|
|
|
|
os << "vector<";
|
2018-07-24 02:44:40 +08:00
|
|
|
for (auto dim : v->getShape())
|
|
|
|
os << dim << 'x';
|
2018-07-18 07:56:54 +08:00
|
|
|
os << *v->getElementType() << '>';
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
case Type::Kind::RankedTensor: {
|
|
|
|
auto *v = cast<RankedTensorType>(type);
|
|
|
|
os << "tensor<";
|
|
|
|
for (auto dim : v->getShape()) {
|
|
|
|
if (dim < 0)
|
|
|
|
os << '?';
|
|
|
|
else
|
|
|
|
os << dim;
|
|
|
|
os << 'x';
|
|
|
|
}
|
|
|
|
os << *v->getElementType() << '>';
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
case Type::Kind::UnrankedTensor: {
|
|
|
|
auto *v = cast<UnrankedTensorType>(type);
|
2018-09-14 01:43:35 +08:00
|
|
|
os << "tensor<*x";
|
2018-07-26 03:55:50 +08:00
|
|
|
printType(v->getElementType());
|
|
|
|
os << '>';
|
2018-07-18 07:56:54 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
case Type::Kind::MemRef: {
|
|
|
|
auto *v = cast<MemRefType>(type);
|
|
|
|
os << "memref<";
|
|
|
|
for (auto dim : v->getShape()) {
|
|
|
|
if (dim < 0)
|
|
|
|
os << '?';
|
|
|
|
else
|
|
|
|
os << dim;
|
|
|
|
os << 'x';
|
|
|
|
}
|
2018-07-26 03:55:50 +08:00
|
|
|
printType(v->getElementType());
|
2018-07-18 07:56:54 +08:00
|
|
|
for (auto map : v->getAffineMaps()) {
|
|
|
|
os << ", ";
|
2018-07-19 07:29:21 +08:00
|
|
|
printAffineMapReference(map);
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
2018-07-26 03:55:50 +08:00
|
|
|
// Only print the memory space if it is the non-default one.
|
|
|
|
if (v->getMemorySpace())
|
|
|
|
os << ", " << v->getMemorySpace();
|
2018-07-18 07:56:54 +08:00
|
|
|
os << '>';
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Affine expressions and maps
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-10-09 04:47:18 +08:00
|
|
|
/// Prints `expr` at the outermost level. Its enclosing context is treated as
/// weakly binding, so printAffineExprInternal adds no parentheses around the
/// top-level expression.
void ModulePrinter::printAffineExpr(AffineExpr expr) {
  const BindingStrength outermost = BindingStrength::Weak;
  printAffineExprInternal(expr, outermost);
}
|
|
|
|
|
|
|
|
void ModulePrinter::printAffineExprInternal(
|
2018-10-09 04:47:18 +08:00
|
|
|
AffineExpr expr, BindingStrength enclosingTightness) {
|
2018-08-01 07:21:36 +08:00
|
|
|
const char *binopSpelling = nullptr;
|
2018-10-10 01:59:27 +08:00
|
|
|
switch (expr.getKind()) {
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::SymbolId:
|
2018-10-10 01:59:27 +08:00
|
|
|
os << 's' << expr.cast<AffineSymbolExpr>().getPosition();
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::DimId:
|
2018-10-10 01:59:27 +08:00
|
|
|
os << 'd' << expr.cast<AffineDimExpr>().getPosition();
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::Constant:
|
2018-10-10 01:59:27 +08:00
|
|
|
os << expr.cast<AffineConstantExpr>().getValue();
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::Add:
|
2018-08-01 07:21:36 +08:00
|
|
|
binopSpelling = " + ";
|
|
|
|
break;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::Mul:
|
2018-08-01 07:21:36 +08:00
|
|
|
binopSpelling = " * ";
|
|
|
|
break;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::FloorDiv:
|
2018-08-01 07:21:36 +08:00
|
|
|
binopSpelling = " floordiv ";
|
|
|
|
break;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::CeilDiv:
|
2018-08-01 07:21:36 +08:00
|
|
|
binopSpelling = " ceildiv ";
|
|
|
|
break;
|
2018-10-09 01:20:25 +08:00
|
|
|
case AffineExprKind::Mod:
|
2018-08-01 07:21:36 +08:00
|
|
|
binopSpelling = " mod ";
|
|
|
|
break;
|
2018-07-21 00:35:47 +08:00
|
|
|
}
|
|
|
|
|
2018-10-09 04:47:18 +08:00
|
|
|
auto binOp = expr.cast<AffineBinaryOpExpr>();
|
2018-07-21 00:35:47 +08:00
|
|
|
|
2018-08-01 07:21:36 +08:00
|
|
|
// Handle tightly binding binary operators.
|
2018-10-10 01:59:27 +08:00
|
|
|
if (binOp.getKind() != AffineExprKind::Add) {
|
2018-08-01 07:21:36 +08:00
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << '(';
|
|
|
|
|
2018-10-10 01:59:27 +08:00
|
|
|
printAffineExprInternal(binOp.getLHS(), BindingStrength::Strong);
|
2018-08-01 07:21:36 +08:00
|
|
|
os << binopSpelling;
|
2018-10-10 01:59:27 +08:00
|
|
|
printAffineExprInternal(binOp.getRHS(), BindingStrength::Strong);
|
2018-08-01 07:21:36 +08:00
|
|
|
|
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << ')';
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Print out special "pretty" forms for add.
|
2018-08-01 07:21:36 +08:00
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << '(';
|
2018-07-21 00:35:47 +08:00
|
|
|
|
|
|
|
// Pretty print addition to a product that has a negative operand as a
|
|
|
|
// subtraction.
|
2018-10-10 01:59:27 +08:00
|
|
|
AffineExpr rhsExpr = binOp.getRHS();
|
2018-10-09 04:47:18 +08:00
|
|
|
if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
|
2018-10-10 01:59:27 +08:00
|
|
|
if (rhs.getKind() == AffineExprKind::Mul) {
|
|
|
|
AffineExpr rrhsExpr = rhs.getRHS();
|
2018-10-09 04:47:18 +08:00
|
|
|
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
|
2018-10-10 01:59:27 +08:00
|
|
|
if (rrhs.getValue() == -1) {
|
|
|
|
printAffineExprInternal(binOp.getLHS(), BindingStrength::Weak);
|
2018-08-01 07:21:36 +08:00
|
|
|
os << " - ";
|
2018-10-19 09:18:04 +08:00
|
|
|
if (rhs.getLHS().getKind() == AffineExprKind::Add) {
|
|
|
|
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
|
|
|
|
} else {
|
|
|
|
printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak);
|
|
|
|
}
|
2018-08-01 07:21:36 +08:00
|
|
|
|
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << ')';
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2018-10-10 01:59:27 +08:00
|
|
|
if (rrhs.getValue() < -1) {
|
|
|
|
printAffineExprInternal(binOp.getLHS(), BindingStrength::Weak);
|
2018-08-02 13:02:00 +08:00
|
|
|
os << " - ";
|
2018-10-10 01:59:27 +08:00
|
|
|
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
|
|
|
|
os << " * " << -rrhs.getValue();
|
2018-08-01 07:21:36 +08:00
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << ')';
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Pretty print addition to a negative number as a subtraction.
|
2018-10-09 04:47:18 +08:00
|
|
|
if (auto rhs = rhsExpr.dyn_cast<AffineConstantExpr>()) {
|
2018-10-10 01:59:27 +08:00
|
|
|
if (rhs.getValue() < 0) {
|
|
|
|
printAffineExprInternal(binOp.getLHS(), BindingStrength::Weak);
|
|
|
|
os << " - " << -rhs.getValue();
|
2018-08-08 05:24:38 +08:00
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << ')';
|
2018-07-21 00:35:47 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-10-10 01:59:27 +08:00
|
|
|
printAffineExprInternal(binOp.getLHS(), BindingStrength::Weak);
|
2018-07-21 00:35:47 +08:00
|
|
|
os << " + ";
|
2018-10-10 01:59:27 +08:00
|
|
|
printAffineExprInternal(binOp.getRHS(), BindingStrength::Weak);
|
2018-08-01 07:21:36 +08:00
|
|
|
|
|
|
|
if (enclosingTightness == BindingStrength::Strong)
|
|
|
|
os << ')';
|
2018-07-21 00:35:47 +08:00
|
|
|
}
|
|
|
|
|
2018-10-09 04:47:18 +08:00
|
|
|
/// Prints `expr` as a single integer-set constraint: the expression itself
/// followed by " == 0" when `isEq` is true, or " >= 0" otherwise.
void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  const char *relation = isEq ? " == 0" : " >= 0";
  os << relation;
}
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
/// Prints `map` in its textual form:
///   (d0, d1, ...)[s0, s1, ...] -> (result, ...)
/// The symbol list is omitted when the map has no symbols, and a trailing
/// " size (...)" clause listing the range sizes is emitted only for bounded
/// maps.
void ModulePrinter::printAffineMap(AffineMap map) {
  const unsigned numDims = map.getNumDims();
  const unsigned numSymbols = map.getNumSymbols();

  // Dimension identifiers, e.g. "(d0, d1)".
  os << '(';
  for (unsigned dim = 0; dim < numDims; ++dim) {
    if (dim != 0)
      os << ", ";
    os << 'd' << dim;
  }
  os << ')';

  // Symbol identifiers, e.g. "[s0, s1]"; only printed if there are any.
  if (numSymbols != 0) {
    os << '[';
    for (unsigned sym = 0; sym < numSymbols; ++sym) {
      if (sym != 0)
        os << ", ";
      os << 's' << sym;
    }
    os << ']';
  }

  // An affine map is required to have at least one result expression.
  assert(!map.getResults().empty());
  auto printExpr = [&](AffineExpr expr) { printAffineExpr(expr); };
  os << " -> (";
  interleaveComma(map.getResults(), printExpr);
  os << ')';

  // Bounded maps additionally print their range sizes.
  if (map.isBounded()) {
    os << " size (";
    interleaveComma(map.getRangeSizes(), printExpr);
    os << ')';
  }
}
|
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
/// Prints `set` in its textual form:
///   (d0, d1, ...)[s0, s1, ...] : (constraint, constraint, ...)
/// where each constraint is printed by printAffineConstraint. The symbol list
/// is omitted when the set has no symbols.
void ModulePrinter::printIntegerSet(IntegerSet set) {
  const unsigned numDims = set.getNumDims();
  const unsigned numSymbols = set.getNumSymbols();

  // Dimension identifiers, e.g. "(d0, d1)".
  os << '(';
  for (unsigned dim = 0; dim < numDims; ++dim) {
    if (dim != 0)
      os << ", ";
    os << 'd' << dim;
  }
  os << ')';

  // Symbol identifiers, e.g. "[s0]"; only printed if there are any.
  if (numSymbols != 0) {
    os << '[';
    for (unsigned sym = 0; sym < numSymbols; ++sym) {
      if (sym != 0)
        os << ", ";
      os << 's' << sym;
    }
    os << ']';
  }

  // Comma-separated constraint list.
  os << " : (";
  for (unsigned idx = 0, e = set.getNumConstraints(); idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    printAffineConstraint(set.getConstraint(idx), set.isEq(idx));
  }
  os << ')';
}
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Function printing
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-08-07 02:54:39 +08:00
|
|
|
void ModulePrinter::printFunctionResultType(const FunctionType *type) {
|
2018-06-24 07:03:42 +08:00
|
|
|
switch (type->getResults().size()) {
|
2018-07-19 01:16:05 +08:00
|
|
|
case 0:
|
|
|
|
break;
|
2018-06-24 07:03:42 +08:00
|
|
|
case 1:
|
2018-07-18 07:56:54 +08:00
|
|
|
os << " -> ";
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(type->getResults()[0]);
|
2018-06-24 07:03:42 +08:00
|
|
|
break;
|
|
|
|
default:
|
|
|
|
os << " -> (";
|
2018-07-25 07:07:22 +08:00
|
|
|
interleaveComma(type->getResults(),
|
|
|
|
[&](Type *eltType) { printType(eltType); });
|
2018-06-24 07:03:42 +08:00
|
|
|
os << ')';
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-09-19 07:36:26 +08:00
|
|
|
void ModulePrinter::printFunctionAttributes(const Function *fn) {
|
|
|
|
auto attrs = fn->getAttrs();
|
|
|
|
if (attrs.empty())
|
|
|
|
return;
|
|
|
|
os << "\n attributes ";
|
|
|
|
printOptionalAttrDict(attrs);
|
|
|
|
}
|
|
|
|
|
2018-08-07 02:54:39 +08:00
|
|
|
void ModulePrinter::printFunctionSignature(const Function *fn) {
|
|
|
|
auto type = fn->getType();
|
|
|
|
|
|
|
|
os << "@" << fn->getName() << '(';
|
|
|
|
interleaveComma(type->getInputs(),
|
|
|
|
[&](Type *eltType) { printType(eltType); });
|
|
|
|
os << ')';
|
|
|
|
|
|
|
|
printFunctionResultType(type);
|
|
|
|
}
|
|
|
|
|
2018-09-19 07:36:26 +08:00
|
|
|
/// Prints `attrs` as a brace-enclosed dictionary: " {name: value, ...}".
///
/// Attributes whose name starts with ':' are internal metadata and are never
/// printed. Attributes whose name appears in `elidedAttrs` are skipped too.
/// Nothing at all is printed when no attribute survives the filtering.
void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                          ArrayRef<const char *> elidedAttrs) {
  // If there are no attributes, then there is nothing to be done.
  if (attrs.empty())
    return;

  // Returns true if the caller has requested that `attrName` be dropped.
  auto isElided = [&](StringRef attrName) -> bool {
    for (const char *elide : elidedAttrs)
      if (attrName == StringRef(elide))
        return true;
    return false;
  };

  // Filter out any attributes that shouldn't be included.
  SmallVector<NamedAttribute, 8> filteredAttrs;
  for (auto attr : attrs) {
    auto attrName = attr.first.strref();
    // Never print attributes that start with a colon. These are internal
    // attributes that represent location or other internal metadata. Only
    // this attribute is skipped; the remaining ones must still be printed.
    if (attrName.startswith(":") || isElided(attrName))
      continue;
    filteredAttrs.push_back(attr);
  }

  // If there are no attributes left to print after filtering, then we're done.
  if (filteredAttrs.empty())
    return;

  // Otherwise, print them all out in braces.
  os << " {";
  interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
    os << attr.first << ": ";
    printAttribute(attr.second);
  });
  os << '}';
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::print(const ExtFunction *fn) {
|
2018-06-24 07:03:42 +08:00
|
|
|
os << "extfunc ";
|
2018-07-21 00:35:47 +08:00
|
|
|
printFunctionSignature(fn);
|
2018-09-19 07:36:26 +08:00
|
|
|
printFunctionAttributes(fn);
|
2018-07-19 01:16:05 +08:00
|
|
|
os << '\n';
|
2018-06-24 07:03:42 +08:00
|
|
|
}
|
|
|
|
|
2018-07-10 08:42:46 +08:00
|
|
|
namespace {
|
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
// FunctionPrinter contains common functionality for printing
|
2018-07-10 08:42:46 +08:00
|
|
|
// CFG and ML functions.
|
2018-07-25 07:07:22 +08:00
|
|
|
class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
|
2018-07-10 08:42:46 +08:00
|
|
|
public:
|
2018-07-25 07:07:22 +08:00
|
|
|
FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
|
2018-07-10 08:42:46 +08:00
|
|
|
|
|
|
|
void printOperation(const Operation *op);
|
2018-07-25 07:07:22 +08:00
|
|
|
void printDefaultOp(const Operation *op);
|
|
|
|
|
|
|
|
// Implement OpAsmPrinter.
|
|
|
|
raw_ostream &getStream() const { return os; }
|
|
|
|
void printType(const Type *type) { ModulePrinter::printType(type); }
|
|
|
|
void printAttribute(const Attribute *attr) {
|
|
|
|
ModulePrinter::printAttribute(attr);
|
|
|
|
}
|
2018-10-10 07:39:24 +08:00
|
|
|
void printAffineMap(AffineMap map) {
|
2018-07-29 00:36:25 +08:00
|
|
|
return ModulePrinter::printAffineMapReference(map);
|
2018-07-25 07:07:22 +08:00
|
|
|
}
|
2018-10-11 00:45:59 +08:00
|
|
|
void printIntegerSet(IntegerSet set) {
|
2018-08-08 05:24:38 +08:00
|
|
|
return ModulePrinter::printIntegerSetReference(set);
|
|
|
|
}
|
2018-10-09 04:47:18 +08:00
|
|
|
void printAffineExpr(AffineExpr expr) {
|
2018-07-25 07:07:22 +08:00
|
|
|
return ModulePrinter::printAffineExpr(expr);
|
|
|
|
}
|
2018-08-22 08:55:22 +08:00
|
|
|
void printFunctionReference(const Function *func) {
|
|
|
|
return ModulePrinter::printFunctionReference(func);
|
|
|
|
}
|
2018-09-19 07:36:26 +08:00
|
|
|
void printFunctionAttributes(const Function *func) {
|
|
|
|
return ModulePrinter::printFunctionAttributes(func);
|
|
|
|
}
|
2018-07-25 07:07:22 +08:00
|
|
|
void printOperand(const SSAValue *value) { printValueID(value); }
|
2018-08-04 02:12:34 +08:00
|
|
|
|
2018-08-03 07:54:36 +08:00
|
|
|
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
|
2018-09-19 07:36:26 +08:00
|
|
|
ArrayRef<const char *> elidedAttrs = {}) {
|
|
|
|
return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
|
|
|
|
};
|
2018-07-10 08:42:46 +08:00
|
|
|
|
2018-08-02 01:43:18 +08:00
|
|
|
enum { nameSentinel = ~0U };
|
|
|
|
|
2018-07-10 08:42:46 +08:00
|
|
|
protected:
|
2018-07-21 00:28:54 +08:00
|
|
|
void numberValueID(const SSAValue *value) {
|
|
|
|
assert(!valueIDs.count(value) && "Value numbered multiple times");
|
2018-08-02 01:43:18 +08:00
|
|
|
|
|
|
|
SmallString<32> specialNameBuffer;
|
|
|
|
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
|
|
|
|
|
|
|
// Give constant integers special names.
|
|
|
|
if (auto *op = value->getDefiningOperation()) {
|
2018-10-20 00:07:58 +08:00
|
|
|
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
|
2018-08-03 08:16:58 +08:00
|
|
|
// i1 constants get special names.
|
|
|
|
if (intOp->getType()->isInteger(1)) {
|
|
|
|
specialName << (intOp->getValue() ? "true" : "false");
|
|
|
|
} else {
|
2018-08-08 03:02:37 +08:00
|
|
|
specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
|
2018-08-03 08:16:58 +08:00
|
|
|
}
|
2018-10-20 00:07:58 +08:00
|
|
|
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
2018-08-08 03:02:37 +08:00
|
|
|
specialName << 'c' << intOp->getValue();
|
2018-10-20 00:07:58 +08:00
|
|
|
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
|
2018-08-20 12:17:22 +08:00
|
|
|
if (isa<FunctionAttr>(constant->getValue()))
|
|
|
|
specialName << 'f';
|
|
|
|
else
|
|
|
|
specialName << "cst";
|
2018-08-02 01:43:18 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (specialNameBuffer.empty()) {
|
|
|
|
switch (value->getKind()) {
|
|
|
|
case SSAValueKind::BBArgument:
|
|
|
|
// If this is an argument to the function, give it an 'arg' name.
|
|
|
|
if (auto *bb = cast<BBArgument>(value)->getOwner())
|
|
|
|
if (auto *fn = bb->getFunction())
|
|
|
|
if (&fn->front() == bb) {
|
|
|
|
specialName << "arg" << nextArgumentID++;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
// Otherwise number it normally.
|
|
|
|
LLVM_FALLTHROUGH;
|
|
|
|
case SSAValueKind::InstResult:
|
|
|
|
case SSAValueKind::StmtResult:
|
|
|
|
// This is an uninteresting result, give it a boring number and be
|
|
|
|
// done with it.
|
|
|
|
valueIDs[value] = nextValueID++;
|
|
|
|
return;
|
2018-08-07 02:54:39 +08:00
|
|
|
case SSAValueKind::MLFuncArgument:
|
2018-08-02 01:43:18 +08:00
|
|
|
specialName << "arg" << nextArgumentID++;
|
|
|
|
break;
|
|
|
|
case SSAValueKind::ForStmt:
|
|
|
|
specialName << 'i' << nextLoopID++;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ok, this value had an interesting name. Remember it with a sentinel.
|
|
|
|
valueIDs[value] = nameSentinel;
|
|
|
|
|
|
|
|
// Remember that we've used this name, checking to see if we had a conflict.
|
|
|
|
auto insertRes = usedNames.insert(specialName.str());
|
|
|
|
if (insertRes.second) {
|
|
|
|
// If this is the first use of the name, then we're successful!
|
|
|
|
valueNames[value] = insertRes.first->first();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Otherwise, we had a conflict - probe until we find a unique name. This
|
|
|
|
// is guaranteed to terminate (and usually in a single iteration) because it
|
|
|
|
// generates new names by incrementing nextConflictID.
|
|
|
|
while (1) {
|
|
|
|
std::string probeName =
|
|
|
|
specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
|
|
|
|
insertRes = usedNames.insert(probeName);
|
|
|
|
if (insertRes.second) {
|
|
|
|
// If this is the first use of the name, then we're successful!
|
|
|
|
valueNames[value] = insertRes.first->first();
|
|
|
|
return;
|
|
|
|
}
|
2018-07-31 06:18:10 +08:00
|
|
|
}
|
2018-07-21 00:28:54 +08:00
|
|
|
}
|
|
|
|
|
2018-07-31 06:18:10 +08:00
|
|
|
void printValueID(const SSAValue *value, bool printResultNo = true) const {
|
2018-07-21 09:41:34 +08:00
|
|
|
int resultNo = -1;
|
|
|
|
auto lookupValue = value;
|
|
|
|
|
2018-07-31 06:18:10 +08:00
|
|
|
// If this is a reference to the result of a multi-result instruction or
|
|
|
|
// statement, print out the # identifier and make sure to map our lookup
|
|
|
|
// to the first result of the instruction.
|
2018-07-21 09:41:34 +08:00
|
|
|
if (auto *result = dyn_cast<InstResult>(value)) {
|
|
|
|
if (result->getOwner()->getNumResults() != 1) {
|
|
|
|
resultNo = result->getResultNumber();
|
|
|
|
lookupValue = result->getOwner()->getResult(0);
|
|
|
|
}
|
2018-07-31 06:18:10 +08:00
|
|
|
} else if (auto *result = dyn_cast<StmtResult>(value)) {
|
|
|
|
if (result->getOwner()->getNumResults() != 1) {
|
|
|
|
resultNo = result->getResultNumber();
|
|
|
|
lookupValue = result->getOwner()->getResult(0);
|
|
|
|
}
|
2018-07-21 09:41:34 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
auto it = valueIDs.find(lookupValue);
|
|
|
|
if (it == valueIDs.end()) {
|
2018-07-21 00:28:54 +08:00
|
|
|
os << "<<INVALID SSA VALUE>>";
|
2018-07-21 09:41:34 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2018-07-31 06:18:10 +08:00
|
|
|
os << '%';
|
2018-08-02 01:43:18 +08:00
|
|
|
if (it->second != nameSentinel) {
|
|
|
|
os << it->second;
|
|
|
|
} else {
|
|
|
|
auto nameIt = valueNames.find(lookupValue);
|
|
|
|
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
|
|
|
|
os << nameIt->second;
|
|
|
|
}
|
2018-07-31 06:18:10 +08:00
|
|
|
|
|
|
|
if (resultNo != -1 && printResultNo)
|
2018-07-21 09:41:34 +08:00
|
|
|
os << '#' << resultNo;
|
2018-07-21 00:28:54 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
2018-08-02 01:43:18 +08:00
|
|
|
/// This is the value ID for each SSA value in the current function. If this
|
|
|
|
/// returns ~0, then the valueID has an entry in valueNames.
|
2018-07-21 00:28:54 +08:00
|
|
|
DenseMap<const SSAValue *, unsigned> valueIDs;
|
2018-08-02 01:43:18 +08:00
|
|
|
DenseMap<const SSAValue *, StringRef> valueNames;
|
|
|
|
|
|
|
|
/// This keeps track of all of the non-numeric names that are in flight,
|
|
|
|
/// allowing us to check for duplicates.
|
|
|
|
llvm::StringSet<> usedNames;
|
|
|
|
|
|
|
|
/// This is the next value ID to assign in numbering.
|
2018-07-21 00:28:54 +08:00
|
|
|
unsigned nextValueID = 0;
|
2018-08-02 01:43:18 +08:00
|
|
|
/// This is the ID to assign to the next induction variable.
|
2018-07-31 22:40:14 +08:00
|
|
|
unsigned nextLoopID = 0;
|
2018-08-02 01:43:18 +08:00
|
|
|
/// This is the next ID to assign to an MLFunction argument.
|
|
|
|
unsigned nextArgumentID = 0;
|
|
|
|
|
|
|
|
/// This is the next ID to assign when a name conflict is detected.
|
|
|
|
unsigned nextConflictID = 0;
|
2018-07-10 08:42:46 +08:00
|
|
|
};
|
2018-07-24 02:44:40 +08:00
|
|
|
} // end anonymous namespace
|
2018-07-10 08:42:46 +08:00
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
void FunctionPrinter::printOperation(const Operation *op) {
|
2018-07-23 12:02:26 +08:00
|
|
|
if (op->getNumResults()) {
|
2018-07-31 06:18:10 +08:00
|
|
|
printValueID(op->getResult(0), /*printResultNo=*/false);
|
2018-07-23 12:02:26 +08:00
|
|
|
os << " = ";
|
2018-07-21 00:28:54 +08:00
|
|
|
}
|
|
|
|
|
2018-07-10 08:42:46 +08:00
|
|
|
// Check to see if this is a known operation. If so, use the registered
|
|
|
|
// custom printer hook.
|
2018-10-10 13:08:52 +08:00
|
|
|
if (auto *opInfo = op->getAbstractOperation()) {
|
2018-07-25 07:07:22 +08:00
|
|
|
opInfo->printAssembly(op, this);
|
2018-07-10 08:42:46 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:28:54 +08:00
|
|
|
// Otherwise use the standard verbose printing approach.
|
2018-07-25 07:07:22 +08:00
|
|
|
printDefaultOp(op);
|
|
|
|
}
|
2018-07-21 00:28:54 +08:00
|
|
|
|
2018-07-25 07:07:22 +08:00
|
|
|
void FunctionPrinter::printDefaultOp(const Operation *op) {
|
2018-08-16 00:09:54 +08:00
|
|
|
os << '"';
|
2018-10-10 13:08:52 +08:00
|
|
|
printEscapedString(op->getName().getStringRef(), os);
|
2018-08-16 00:09:54 +08:00
|
|
|
os << "\"(";
|
2018-07-10 08:42:46 +08:00
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
interleaveComma(op->getOperands(),
|
|
|
|
[&](const SSAValue *value) { printValueID(value); });
|
2018-07-19 23:35:28 +08:00
|
|
|
|
2018-07-21 00:28:54 +08:00
|
|
|
os << ')';
|
2018-07-10 08:42:46 +08:00
|
|
|
auto attrs = op->getAttrs();
|
2018-08-03 07:54:36 +08:00
|
|
|
printOptionalAttrDict(attrs);
|
2018-07-19 06:31:25 +08:00
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
// Print the type signature of the operation.
|
|
|
|
os << " : (";
|
|
|
|
interleaveComma(op->getOperands(),
|
2018-07-25 07:07:22 +08:00
|
|
|
[&](const SSAValue *value) { printType(value->getType()); });
|
2018-07-23 12:02:26 +08:00
|
|
|
os << ") -> ";
|
2018-07-21 00:28:54 +08:00
|
|
|
|
2018-07-23 12:02:26 +08:00
|
|
|
if (op->getNumResults() == 1) {
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(op->getResult(0)->getType());
|
2018-07-23 12:02:26 +08:00
|
|
|
} else {
|
|
|
|
os << '(';
|
2018-07-25 07:07:22 +08:00
|
|
|
interleaveComma(op->getResults(), [&](const SSAValue *result) {
|
|
|
|
printType(result->getType());
|
|
|
|
});
|
2018-07-23 12:02:26 +08:00
|
|
|
os << ')';
|
2018-07-21 00:28:54 +08:00
|
|
|
}
|
2018-07-10 08:42:46 +08:00
|
|
|
}
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// CFG Function printing
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
namespace {
|
2018-07-25 07:07:22 +08:00
|
|
|
class CFGFunctionPrinter : public FunctionPrinter {
|
2018-06-24 07:03:42 +08:00
|
|
|
public:
|
2018-07-21 00:35:47 +08:00
|
|
|
CFGFunctionPrinter(const CFGFunction *function, const ModulePrinter &other);
|
2018-06-24 07:03:42 +08:00
|
|
|
|
|
|
|
const CFGFunction *getFunction() const { return function; }
|
|
|
|
|
|
|
|
void print();
|
|
|
|
void print(const BasicBlock *block);
|
2018-06-29 11:45:33 +08:00
|
|
|
|
|
|
|
void print(const Instruction *inst);
|
|
|
|
void print(const OperationInst *inst);
|
|
|
|
void print(const ReturnInst *inst);
|
|
|
|
void print(const BranchInst *inst);
|
2018-07-25 06:01:27 +08:00
|
|
|
void print(const CondBranchInst *inst);
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-09-22 05:40:36 +08:00
|
|
|
void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
unsigned getBBID(const BasicBlock *block) {
|
|
|
|
auto it = basicBlockIDs.find(block);
|
|
|
|
assert(it != basicBlockIDs.end() && "Block not in this function?");
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
const CFGFunction *function;
|
2018-07-19 01:16:05 +08:00
|
|
|
DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
|
2018-07-21 00:28:54 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void numberValuesInBlock(const BasicBlock *block);
|
2018-06-24 07:03:42 +08:00
|
|
|
};
|
2018-07-24 02:44:40 +08:00
|
|
|
} // end anonymous namespace
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
/// Builds a printer for `function`. Each basic block gets a sequential ID in
/// function order, and the SSA values defined in each block are numbered.
CFGFunctionPrinter::CFGFunctionPrinter(const CFGFunction *function,
                                       const ModulePrinter &other)
    : FunctionPrinter(other), function(function) {
  unsigned nextBlockID = 0;
  for (auto &block : *function) {
    basicBlockIDs[&block] = nextBlockID;
    ++nextBlockID;
    numberValuesInBlock(&block);
  }
}
|
|
|
|
|
|
|
|
/// Number all of the SSA values in the specified basic block.
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
|
2018-07-23 06:45:24 +08:00
|
|
|
for (auto *arg : block->getArguments()) {
|
|
|
|
numberValueID(arg);
|
|
|
|
}
|
2018-07-21 00:28:54 +08:00
|
|
|
for (auto &op : *block) {
|
|
|
|
// We number instruction that have results, and we only number the first
|
|
|
|
// result.
|
|
|
|
if (op.getNumResults() != 0)
|
|
|
|
numberValueID(op.getResult(0));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Terminators do not define values.
|
2018-06-24 07:03:42 +08:00
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::print() {
|
2018-06-24 07:03:42 +08:00
|
|
|
os << "cfgfunc ";
|
2018-07-21 00:35:47 +08:00
|
|
|
printFunctionSignature(getFunction());
|
2018-09-19 07:36:26 +08:00
|
|
|
printFunctionAttributes(getFunction());
|
2018-06-24 07:03:42 +08:00
|
|
|
os << " {\n";
|
|
|
|
|
2018-07-24 02:44:40 +08:00
|
|
|
for (auto &block : *function)
|
|
|
|
print(&block);
|
2018-06-24 07:03:42 +08:00
|
|
|
os << "}\n\n";
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::print(const BasicBlock *block) {
|
2018-09-22 05:40:36 +08:00
|
|
|
printBBName(block);
|
2018-07-23 06:45:24 +08:00
|
|
|
|
|
|
|
if (!block->args_empty()) {
|
|
|
|
os << '(';
|
|
|
|
interleaveComma(block->getArguments(), [&](const BBArgument *arg) {
|
|
|
|
printValueID(arg);
|
|
|
|
os << ": ";
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(arg->getType());
|
2018-07-23 06:45:24 +08:00
|
|
|
});
|
|
|
|
os << ')';
|
|
|
|
}
|
2018-07-28 02:10:12 +08:00
|
|
|
os << ':';
|
|
|
|
|
|
|
|
// Print out some context information about the predecessors of this block.
|
|
|
|
if (!block->getFunction()) {
|
|
|
|
os << "\t// block is not in a function!";
|
|
|
|
} else if (block->hasNoPredecessors()) {
|
|
|
|
// Don't print "no predecessors" for the entry block.
|
|
|
|
if (block != &block->getFunction()->front())
|
|
|
|
os << "\t// no predecessors";
|
|
|
|
} else if (auto *pred = block->getSinglePredecessor()) {
|
2018-09-22 05:40:36 +08:00
|
|
|
os << "\t// pred: ";
|
|
|
|
printBBName(pred);
|
2018-07-28 02:10:12 +08:00
|
|
|
} else {
|
|
|
|
// We want to print the predecessors in increasing numeric order, not in
|
|
|
|
// whatever order the use-list is in, so gather and sort them.
|
|
|
|
SmallVector<unsigned, 4> predIDs;
|
|
|
|
for (auto *pred : block->getPredecessors())
|
|
|
|
predIDs.push_back(getBBID(pred));
|
|
|
|
llvm::array_pod_sort(predIDs.begin(), predIDs.end());
|
|
|
|
|
|
|
|
os << "\t// " << predIDs.size() << " preds: ";
|
|
|
|
|
|
|
|
interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; });
|
|
|
|
}
|
|
|
|
os << '\n';
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-07-15 15:06:54 +08:00
|
|
|
for (auto &inst : block->getOperations()) {
|
2018-07-27 09:09:20 +08:00
|
|
|
os << " ";
|
2018-07-02 11:28:00 +08:00
|
|
|
print(&inst);
|
2018-07-23 06:45:24 +08:00
|
|
|
os << '\n';
|
2018-07-15 15:06:54 +08:00
|
|
|
}
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-08-03 18:51:38 +08:00
|
|
|
os << " ";
|
2018-06-24 07:03:42 +08:00
|
|
|
print(block->getTerminator());
|
2018-07-23 06:45:24 +08:00
|
|
|
os << '\n';
|
2018-06-24 07:03:42 +08:00
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::print(const Instruction *inst) {
|
2018-09-07 00:17:08 +08:00
|
|
|
if (!inst) {
|
|
|
|
os << "<<null instruction>>\n";
|
|
|
|
return;
|
|
|
|
}
|
2018-06-24 07:03:42 +08:00
|
|
|
switch (inst->getKind()) {
|
2018-06-29 11:45:33 +08:00
|
|
|
case Instruction::Kind::Operation:
|
|
|
|
return print(cast<OperationInst>(inst));
|
2018-06-25 02:18:29 +08:00
|
|
|
case TerminatorInst::Kind::Branch:
|
2018-06-29 11:45:33 +08:00
|
|
|
return print(cast<BranchInst>(inst));
|
2018-07-25 06:01:27 +08:00
|
|
|
case TerminatorInst::Kind::CondBranch:
|
|
|
|
return print(cast<CondBranchInst>(inst));
|
2018-06-24 07:03:42 +08:00
|
|
|
case TerminatorInst::Kind::Return:
|
2018-06-29 11:45:33 +08:00
|
|
|
return print(cast<ReturnInst>(inst));
|
2018-06-24 07:03:42 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::print(const OperationInst *inst) {
|
2018-07-10 08:42:46 +08:00
|
|
|
printOperation(inst);
|
2018-07-19 06:31:25 +08:00
|
|
|
}
|
2018-07-23 23:42:19 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void CFGFunctionPrinter::print(const BranchInst *inst) {
|
2018-09-22 05:40:36 +08:00
|
|
|
os << "br ";
|
|
|
|
printBBName(inst->getDest());
|
2018-07-23 23:42:19 +08:00
|
|
|
|
|
|
|
if (inst->getNumOperands() != 0) {
|
|
|
|
os << '(';
|
2018-08-16 00:09:54 +08:00
|
|
|
interleaveComma(inst->getOperands(),
|
|
|
|
[&](const CFGValue *operand) { printValueID(operand); });
|
2018-07-23 23:42:19 +08:00
|
|
|
os << ") : ";
|
2018-08-16 00:09:54 +08:00
|
|
|
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
|
|
|
|
printType(operand->getType());
|
2018-07-23 23:42:19 +08:00
|
|
|
});
|
|
|
|
}
|
2018-06-29 11:45:33 +08:00
|
|
|
}
|
2018-07-23 23:42:19 +08:00
|
|
|
|
2018-07-25 06:01:27 +08:00
|
|
|
void CFGFunctionPrinter::print(const CondBranchInst *inst) {
|
2018-07-27 09:09:20 +08:00
|
|
|
os << "cond_br ";
|
2018-07-25 06:01:27 +08:00
|
|
|
printValueID(inst->getCondition());
|
|
|
|
|
2018-09-22 05:40:36 +08:00
|
|
|
os << ", ";
|
|
|
|
printBBName(inst->getTrueDest());
|
2018-07-25 06:01:27 +08:00
|
|
|
if (inst->getNumTrueOperands() != 0) {
|
|
|
|
os << '(';
|
|
|
|
interleaveComma(inst->getTrueOperands(),
|
|
|
|
[&](const CFGValue *operand) { printValueID(operand); });
|
|
|
|
os << " : ";
|
|
|
|
interleaveComma(inst->getTrueOperands(), [&](const CFGValue *operand) {
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(operand->getType());
|
2018-07-25 06:01:27 +08:00
|
|
|
});
|
|
|
|
os << ")";
|
|
|
|
}
|
|
|
|
|
2018-09-22 05:40:36 +08:00
|
|
|
os << ", ";
|
|
|
|
printBBName(inst->getFalseDest());
|
2018-07-25 06:01:27 +08:00
|
|
|
if (inst->getNumFalseOperands() != 0) {
|
|
|
|
os << '(';
|
|
|
|
interleaveComma(inst->getFalseOperands(),
|
|
|
|
[&](const CFGValue *operand) { printValueID(operand); });
|
|
|
|
os << " : ";
|
|
|
|
interleaveComma(inst->getFalseOperands(), [&](const CFGValue *operand) {
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(operand->getType());
|
2018-07-25 06:01:27 +08:00
|
|
|
});
|
|
|
|
os << ")";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-22 05:32:09 +08:00
|
|
|
void CFGFunctionPrinter::print(const ReturnInst *inst) {
|
2018-07-27 09:09:20 +08:00
|
|
|
os << "return";
|
2018-07-22 05:32:09 +08:00
|
|
|
|
2018-08-02 23:28:20 +08:00
|
|
|
if (inst->getNumOperands() == 0)
|
|
|
|
return;
|
2018-07-22 05:32:09 +08:00
|
|
|
|
2018-08-02 23:28:20 +08:00
|
|
|
os << ' ';
|
2018-07-25 06:01:27 +08:00
|
|
|
interleaveComma(inst->getOperands(),
|
|
|
|
[&](const CFGValue *operand) { printValueID(operand); });
|
|
|
|
os << " : ";
|
2018-07-23 12:02:26 +08:00
|
|
|
interleaveComma(inst->getOperands(), [&](const CFGValue *operand) {
|
2018-07-25 07:07:22 +08:00
|
|
|
printType(operand->getType());
|
2018-07-22 05:32:09 +08:00
|
|
|
});
|
|
|
|
}
|
2018-07-19 01:16:05 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::print(const CFGFunction *fn) {
|
|
|
|
CFGFunctionPrinter(fn, *this).print();
|
2018-06-29 11:45:33 +08:00
|
|
|
}
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-07-04 08:51:28 +08:00
|
|
|
// ML Function printing
|
2018-06-24 07:03:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2018-07-04 08:51:28 +08:00
|
|
|
namespace {
|
2018-07-25 07:07:22 +08:00
|
|
|
class MLFunctionPrinter : public FunctionPrinter {
|
2018-07-04 08:51:28 +08:00
|
|
|
public:
|
2018-07-21 00:35:47 +08:00
|
|
|
MLFunctionPrinter(const MLFunction *function, const ModulePrinter &other);
|
2018-07-04 08:51:28 +08:00
|
|
|
|
|
|
|
const MLFunction *getFunction() const { return function; }
|
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
// Prints ML function.
|
2018-07-04 08:51:28 +08:00
|
|
|
void print();
|
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
// Prints ML function signature.
|
2018-08-07 02:54:39 +08:00
|
|
|
void printFunctionSignature();
|
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
// Methods to print ML function statements.
|
2018-07-04 08:51:28 +08:00
|
|
|
void print(const Statement *stmt);
|
2018-07-10 08:42:46 +08:00
|
|
|
void print(const OperationStmt *stmt);
|
2018-07-04 08:51:28 +08:00
|
|
|
void print(const ForStmt *stmt);
|
|
|
|
void print(const IfStmt *stmt);
|
2018-07-14 04:03:13 +08:00
|
|
|
void print(const StmtBlock *block);
|
2018-07-04 08:51:28 +08:00
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
// Print loop bounds.
|
|
|
|
void printDimAndSymbolList(ArrayRef<StmtOperand> ops, unsigned numDims);
|
|
|
|
void printBound(AffineBound bound, const char *prefix);
|
|
|
|
|
|
|
|
// Number of spaces used for indenting nested statements.
|
2018-07-14 04:03:13 +08:00
|
|
|
const static unsigned indentWidth = 2;
|
2018-07-04 08:51:28 +08:00
|
|
|
|
2018-07-14 04:03:13 +08:00
|
|
|
private:
|
2018-07-27 09:09:20 +08:00
|
|
|
void numberValues();
|
|
|
|
|
2018-07-04 08:51:28 +08:00
|
|
|
const MLFunction *function;
|
|
|
|
int numSpaces;
|
|
|
|
};
|
2018-07-24 02:44:40 +08:00
|
|
|
} // end anonymous namespace
|
2018-07-04 08:51:28 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
|
|
|
|
const ModulePrinter &other)
|
2018-07-27 09:09:20 +08:00
|
|
|
: FunctionPrinter(other), function(function), numSpaces(0) {
|
2018-08-07 02:54:39 +08:00
|
|
|
assert(function && "Cannot print nullptr function");
|
2018-07-27 09:09:20 +08:00
|
|
|
numberValues();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Number all of the SSA values in this ML function.
|
|
|
|
void MLFunctionPrinter::numberValues() {
|
2018-08-25 14:38:14 +08:00
|
|
|
// Numbers ML function arguments.
|
2018-08-07 02:54:39 +08:00
|
|
|
for (auto *arg : function->getArguments())
|
|
|
|
numberValueID(arg);
|
|
|
|
|
|
|
|
// Walks ML function statements and numbers for statements and
|
|
|
|
// the first result of the operation statements.
|
2018-07-28 01:58:14 +08:00
|
|
|
struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
|
2018-07-27 09:09:20 +08:00
|
|
|
NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
|
|
|
|
void visitOperationStmt(OperationStmt *stmt) {
|
|
|
|
if (stmt->getNumResults() != 0)
|
|
|
|
printer->numberValueID(stmt->getResult(0));
|
|
|
|
}
|
2018-07-31 22:40:14 +08:00
|
|
|
void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
|
2018-07-27 09:09:20 +08:00
|
|
|
MLFunctionPrinter *printer;
|
|
|
|
};
|
|
|
|
|
|
|
|
NumberValuesPass pass(this);
|
2018-08-16 00:09:54 +08:00
|
|
|
// TODO: it'd be cleaner to have constant visitor instead of using const_cast.
|
2018-07-28 01:58:14 +08:00
|
|
|
pass.walk(const_cast<MLFunction *>(function));
|
2018-07-27 09:09:20 +08:00
|
|
|
}
|
2018-07-04 08:51:28 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print() {
|
2018-07-04 08:51:28 +08:00
|
|
|
os << "mlfunc ";
|
2018-08-07 02:54:39 +08:00
|
|
|
printFunctionSignature();
|
2018-09-19 07:36:26 +08:00
|
|
|
printFunctionAttributes(getFunction());
|
2018-07-04 08:51:28 +08:00
|
|
|
os << " {\n";
|
2018-07-14 04:03:13 +08:00
|
|
|
print(function);
|
2018-07-04 08:51:28 +08:00
|
|
|
os << "}\n\n";
|
|
|
|
}
|
|
|
|
|
2018-08-07 02:54:39 +08:00
|
|
|
void MLFunctionPrinter::printFunctionSignature() {
|
|
|
|
auto type = function->getType();
|
|
|
|
|
|
|
|
os << "@" << function->getName() << '(';
|
|
|
|
|
|
|
|
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
|
|
|
|
if (i > 0)
|
|
|
|
os << ", ";
|
|
|
|
auto *arg = function->getArgument(i);
|
|
|
|
printOperand(arg);
|
|
|
|
os << " : ";
|
|
|
|
printType(arg->getType());
|
|
|
|
}
|
|
|
|
os << ")";
|
|
|
|
printFunctionResultType(type);
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print(const StmtBlock *block) {
|
2018-07-14 04:03:13 +08:00
|
|
|
numSpaces += indentWidth;
|
2018-07-15 15:06:54 +08:00
|
|
|
for (auto &stmt : block->getStatements()) {
|
2018-07-14 04:03:13 +08:00
|
|
|
print(&stmt);
|
2018-07-15 15:06:54 +08:00
|
|
|
os << "\n";
|
|
|
|
}
|
2018-07-14 04:03:13 +08:00
|
|
|
numSpaces -= indentWidth;
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print(const Statement *stmt) {
|
2018-07-04 08:51:28 +08:00
|
|
|
switch (stmt->getKind()) {
|
2018-07-17 02:47:09 +08:00
|
|
|
case Statement::Kind::Operation:
|
|
|
|
return print(cast<OperationStmt>(stmt));
|
2018-07-04 08:51:28 +08:00
|
|
|
case Statement::Kind::For:
|
|
|
|
return print(cast<ForStmt>(stmt));
|
|
|
|
case Statement::Kind::If:
|
|
|
|
return print(cast<IfStmt>(stmt));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print(const OperationStmt *stmt) {
|
2018-07-27 09:09:20 +08:00
|
|
|
os.indent(numSpaces);
|
2018-07-21 00:35:47 +08:00
|
|
|
printOperation(stmt);
|
|
|
|
}
|
2018-07-04 08:51:28 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print(const ForStmt *stmt) {
|
2018-07-31 06:18:10 +08:00
|
|
|
os.indent(numSpaces) << "for ";
|
2018-07-31 22:40:14 +08:00
|
|
|
printOperand(stmt);
|
2018-08-25 14:38:14 +08:00
|
|
|
os << " = ";
|
|
|
|
printBound(stmt->getLowerBound(), "max");
|
|
|
|
os << " to ";
|
|
|
|
printBound(stmt->getUpperBound(), "min");
|
|
|
|
|
Extend loop unrolling to unroll by a given factor; add builder for affine
apply op.
- add builder for AffineApplyOp (first one for an operation that has
non-zero operands)
- add support for loop unrolling by a given factor; uses the affine apply op
builder.
While on this, change 'step' of ForStmt to be 'unsigned' instead of
AffineConstantExpr *. Add setters for ForStmt lb, ub, step.
Sample Input:
// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
for %i = 1 to 100 {
for %j = 0 to 17 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
}
return
}
Output:
$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
for %i0 = 1 to 100 {
for %i1 = 0 to 17 step 4 {
%0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
%1 = "addi32"(%0, %0) : (i32, i32) -> i32
%2 = affine_apply #map0(%i1)
%3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
%4 = affine_apply #map1(%i1)
%5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
%6 = affine_apply #map2(%i1)
%7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
}
for %i2 = 16 to 17 {
%8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
%9 = "addi32"(%8, %8) : (i32, i32) -> i32
}
}
return
}
PiperOrigin-RevId: 209676220
2018-08-22 07:01:23 +08:00
|
|
|
if (stmt->getStep() != 1)
|
|
|
|
os << " step " << stmt->getStep();
|
2018-07-20 00:52:39 +08:00
|
|
|
|
|
|
|
os << " {\n";
|
2018-07-14 04:03:13 +08:00
|
|
|
print(static_cast<const StmtBlock *>(stmt));
|
2018-07-17 02:47:09 +08:00
|
|
|
os.indent(numSpaces) << "}";
|
2018-07-04 08:51:28 +08:00
|
|
|
}
|
|
|
|
|
2018-08-25 14:38:14 +08:00
|
|
|
/// Print `ops` as "(dims...)" optionally followed by "[symbols...]". The first
/// `numDims` operands are dimensions; any remaining operands are symbols.
void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
                                              unsigned numDims) {
  auto printStmtOperand = [&](const StmtOperand &operand) {
    printOperand(operand.get());
  };
  auto printSeparator = [&]() { os << ", "; };
  auto dimsEnd = ops.begin() + numDims;

  os << '(';
  interleave(ops.begin(), dimsEnd, printStmtOperand, printSeparator);
  os << ')';

  // The symbol list is emitted only when at least one symbol operand exists.
  if (!(numDims < ops.size()))
    return;
  os << '[';
  interleave(dimsEnd, ops.end(), printStmtOperand, printSeparator);
  os << ']';
}
|
|
|
|
|
|
|
|
/// Print a loop bound. Single-result constant maps and single-symbol identity
/// maps use short-hand notation. All other bounds print the map reference and
/// its operands; multi-result maps are preceded by `prefix` ('min' or 'max').
void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
  AffineMap map = bound.getMap();

  if (map.getNumResults() != 1) {
    // Map has multiple results. Print 'min' or 'max' prefix.
    os << prefix << ' ';
  } else {
    // Short-hand printing is restricted to trivial cases so that MLIR can
    // round-trip binary -> text -> binary without loss: zero-operand constant
    // maps and single-symbol-operand identity maps only.
    AffineExpr expr = map.getResult(0);
    unsigned numDims = map.getNumDims();
    unsigned numSymbols = map.getNumSymbols();

    // Constant bound: print the literal value.
    if (numDims == 0 && numSymbols == 0) {
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
        os << constExpr.getValue();
        return;
      }
    }

    // Bound that is exactly one SSA symbol: print that operand directly.
    if (numDims == 0 && numSymbols == 1) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        printOperand(bound.getOperand(0));
        return;
      }
    }
  }

  // General form: the map followed by its dimension/symbol operands.
  printAffineMapReference(map);
  printDimAndSymbolList(bound.getStmtOperands(), map.getNumDims());
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void MLFunctionPrinter::print(const IfStmt *stmt) {
|
2018-08-29 06:26:20 +08:00
|
|
|
os.indent(numSpaces) << "if ";
|
2018-10-11 00:45:59 +08:00
|
|
|
IntegerSet set = stmt->getIntegerSet();
|
2018-08-29 06:26:20 +08:00
|
|
|
printIntegerSetReference(set);
|
2018-10-11 00:45:59 +08:00
|
|
|
printDimAndSymbolList(stmt->getStmtOperands(), set.getNumDims());
|
2018-08-29 06:26:20 +08:00
|
|
|
os << " {\n";
|
2018-08-09 02:14:57 +08:00
|
|
|
print(stmt->getThen());
|
2018-07-14 04:03:13 +08:00
|
|
|
os.indent(numSpaces) << "}";
|
2018-08-09 02:14:57 +08:00
|
|
|
if (stmt->hasElse()) {
|
2018-07-14 04:03:13 +08:00
|
|
|
os << " else {\n";
|
2018-08-09 02:14:57 +08:00
|
|
|
print(stmt->getElse());
|
2018-07-14 04:03:13 +08:00
|
|
|
os.indent(numSpaces) << "}";
|
|
|
|
}
|
2018-07-04 08:51:28 +08:00
|
|
|
}
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void ModulePrinter::print(const MLFunction *fn) {
|
|
|
|
MLFunctionPrinter(fn, *this).print();
|
2018-07-18 07:56:54 +08:00
|
|
|
}
|
|
|
|
|
2018-07-04 08:51:28 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// print and dump methods
|
|
|
|
//===----------------------------------------------------------------------===//
|
2018-06-29 11:45:33 +08:00
|
|
|
|
2018-07-19 07:29:21 +08:00
|
|
|
/// Print this attribute to `os`. No MLIRContext is supplied to the module
/// state in this path.
void Attribute::print(raw_ostream &os) const {
  ModuleState moduleState(/*no context is known*/ nullptr);
  ModulePrinter printer(os, moduleState);
  printer.printAttribute(this);
}
|
|
|
|
|
2018-07-24 02:44:40 +08:00
|
|
|
void Attribute::dump() const { print(llvm::errs()); }
|
2018-07-19 07:29:21 +08:00
|
|
|
|
2018-07-18 07:56:54 +08:00
|
|
|
/// Print this type to `os`, using the type's own context for module state.
void Type::print(raw_ostream &os) const {
  ModuleState moduleState(getContext());
  ModulePrinter printer(os, moduleState);
  printer.printType(this);
}
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void Type::dump() const { print(llvm::errs()); }
|
2018-07-18 07:56:54 +08:00
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
void AffineMap::dump() const {
|
2018-07-17 00:45:22 +08:00
|
|
|
print(llvm::errs());
|
|
|
|
llvm::errs() << "\n";
|
|
|
|
}
|
2018-07-10 00:00:25 +08:00
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
void IntegerSet::dump() const {
|
2018-08-08 05:24:38 +08:00
|
|
|
print(llvm::errs());
|
|
|
|
llvm::errs() << "\n";
|
|
|
|
}
|
|
|
|
|
2018-10-10 01:59:27 +08:00
|
|
|
/// Print this affine expression to `os`, using the expression's context.
void AffineExpr::print(raw_ostream &os) const {
  ModuleState moduleState(getContext());
  ModulePrinter printer(os, moduleState);
  printer.printAffineExpr(*this);
}
|
|
|
|
|
|
|
|
void AffineExpr::dump() const {
|
|
|
|
print(llvm::errs());
|
|
|
|
llvm::errs() << "\n";
|
2018-06-30 09:09:29 +08:00
|
|
|
}
|
|
|
|
|
2018-10-10 07:39:24 +08:00
|
|
|
/// Print this affine map to `os`, using the map's context for module state.
void AffineMap::print(raw_ostream &os) const {
  ModuleState moduleState(getContext());
  ModulePrinter printer(os, moduleState);
  printer.printAffineMap(*this);
}
|
2018-07-12 12:31:07 +08:00
|
|
|
|
2018-10-11 00:45:59 +08:00
|
|
|
/// Print this integer set to `os`. No MLIRContext is supplied to the module
/// state in this path (unlike AffineMap::print).
void IntegerSet::print(raw_ostream &os) const {
  ModuleState moduleState(/*no context is known*/ nullptr);
  ModulePrinter printer(os, moduleState);
  printer.printIntegerSet(*this);
}
|
|
|
|
|
2018-08-03 08:16:58 +08:00
|
|
|
/// Print this SSA value. Values defined by an instruction or statement print
/// their defining construct; arguments currently print a placeholder.
void SSAValue::print(raw_ostream &os) const {
  switch (getKind()) {
  case SSAValueKind::InstResult:
    getDefiningInst()->print(os);
    return;
  case SSAValueKind::StmtResult:
    getDefiningStmt()->print(os);
    return;
  case SSAValueKind::ForStmt:
    cast<ForStmt>(this)->print(os);
    return;
  case SSAValueKind::BBArgument:
    // TODO: Improve this.
    os << "<bb argument>\n";
    return;
  case SSAValueKind::MLFuncArgument:
    // TODO: Improve this.
    os << "<function argument>\n";
    return;
  }
}
|
|
|
|
|
|
|
|
void SSAValue::dump() const { print(llvm::errs()); }
|
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
/// Print this instruction through a CFGFunctionPrinter for its parent
/// function, or a placeholder if it is not attached to a function.
void Instruction::print(raw_ostream &os) const {
  if (auto *parentFn = getFunction()) {
    ModuleState moduleState(parentFn->getContext());
    ModulePrinter modulePrinter(os, moduleState);
    CFGFunctionPrinter fnPrinter(parentFn, modulePrinter);
    fnPrinter.print(this);
    return;
  }
  os << "<<UNLINKED INSTRUCTION>>\n";
}
|
2018-07-12 12:31:07 +08:00
|
|
|
|
2018-07-21 00:35:47 +08:00
|
|
|
void Instruction::dump() const {
|
|
|
|
print(llvm::errs());
|
|
|
|
llvm::errs() << "\n";
|
2018-06-30 09:09:29 +08:00
|
|
|
}
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
/// Print this basic block through a CFGFunctionPrinter for its parent
/// function, or a placeholder if it is not attached to a function.
void BasicBlock::print(raw_ostream &os) const {
  if (auto *parentFn = getFunction()) {
    ModuleState moduleState(parentFn->getContext());
    ModulePrinter modulePrinter(os, moduleState);
    CFGFunctionPrinter fnPrinter(parentFn, modulePrinter);
    fnPrinter.print(this);
    return;
  }
  os << "<<UNLINKED BLOCK>>\n";
}
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void BasicBlock::dump() const { print(llvm::errs()); }
|
2018-06-24 07:03:42 +08:00
|
|
|
|
2018-09-22 05:40:36 +08:00
|
|
|
/// Print out the name of the basic block without printing its body.
|
|
|
|
void BasicBlock::printAsOperand(raw_ostream &os, bool printType) {
|
|
|
|
if (!getFunction()) {
|
|
|
|
os << "<<UNLINKED BLOCK>>\n";
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
ModuleState state(getFunction()->getContext());
|
|
|
|
ModulePrinter modulePrinter(os, state);
|
|
|
|
CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this);
|
|
|
|
}
|
|
|
|
|
2018-07-04 08:51:28 +08:00
|
|
|
/// Print this statement through an MLFunctionPrinter for its enclosing ML
/// function, or a placeholder if findFunction() returns null.
void Statement::print(raw_ostream &os) const {
  if (MLFunction *enclosingFn = findFunction()) {
    ModuleState moduleState(enclosingFn->getContext());
    ModulePrinter modulePrinter(os, moduleState);
    MLFunctionPrinter fnPrinter(enclosingFn, modulePrinter);
    fnPrinter.print(this);
    return;
  }
  os << "<<UNLINKED STATEMENT>>\n";
}
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void Statement::dump() const { print(llvm::errs()); }
|
2018-07-15 15:06:54 +08:00
|
|
|
|
2018-08-09 02:14:57 +08:00
|
|
|
/// Print the statements of this block through an MLFunctionPrinter for the
/// enclosing ML function. If findFunction() returns null there is no context
/// or value numbering to print with, so a placeholder is emitted instead.
void StmtBlock::printBlock(raw_ostream &os) const {
  MLFunction *function = findFunction();
  if (!function) {
    // Previously this dereferenced a null function; guard it the same way
    // Statement::print guards its own findFunction() result.
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }

  ModuleState state(function->getContext());
  ModulePrinter modulePrinter(os, state);
  MLFunctionPrinter(function, modulePrinter).print(this);
}
|
|
|
|
|
2018-08-09 02:14:57 +08:00
|
|
|
void StmtBlock::dumpBlock() const { printBlock(llvm::errs()); }
|
2018-08-04 04:22:26 +08:00
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
/// Print this function to `os` using a module state built from its context.
void Function::print(raw_ostream &os) const {
  ModuleState moduleState(getContext());
  ModulePrinter printer(os, moduleState);
  printer.print(this);
}
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void Function::dump() const { print(llvm::errs()); }
|
|
|
|
|
2018-06-24 07:03:42 +08:00
|
|
|
/// Print the whole module to `os`. The module state is initialized from this
/// module before any printing begins.
void Module::print(raw_ostream &os) const {
  ModuleState moduleState(getContext());
  moduleState.initialize(this);

  ModulePrinter printer(os, moduleState);
  printer.print(this);
}
|
|
|
|
|
2018-07-19 01:16:05 +08:00
|
|
|
void Module::dump() const { print(llvm::errs()); }
|