Use SFINAE to generalize << overloads, give 'constant' a pretty form,

generalize the asmprinters handling of pretty names to allow arbitrary sugar to
be dumped on various constructs.  Give CFG function arguments nice "arg0" names
like MLFunctions get, and give constant integers pretty names like %c37 for a
constant 377

PiperOrigin-RevId: 206953080
This commit is contained in:
Chris Lattner 2018-08-01 10:43:18 -07:00 committed by jpienaar
parent 48dbfb48d5
commit 8eaf382734
11 changed files with 224 additions and 94 deletions

View File

@ -98,32 +98,16 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, StringRef other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const char *other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, char other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, unsigned other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, int other) {
p.getStream() << other;
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, float other) {
// Support printing anything that isn't convertible to one of the above types,
// even if it isn't exactly one of them. For example, we want to print
// FunctionType with the Type& version above, not have it match this.
template <typename T, typename std::enable_if<
!std::is_convertible<T &, SSAValue &>::value &&
!std::is_convertible<T &, Type &>::value &&
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, AffineMap &>::value,
T>::type * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
p.getStream() << other;
return p;
}

View File

@ -30,6 +30,7 @@
namespace mlir {
class OperationInst;
class OperationStmt;
class Operation;
/// This enumerates all of the SSA value kinds in the MLIR system.
enum class SSAValueKind {
@ -72,6 +73,13 @@ public:
return const_cast<SSAValue *>(this)->getDefiningStmt();
}
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *getDefiningOperation();
const Operation *getDefiningOperation() const {
return const_cast<SSAValue *>(this)->getDefiningOperation();
}
protected:
SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
private:

View File

@ -140,6 +140,8 @@ public:
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
static OpAsmParserResult parse(OpAsmParser *parser);
void print(OpAsmPrinter *p) const;
const char *verify() const;
protected:

View File

@ -28,12 +28,15 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); }
@ -574,24 +577,75 @@ public:
void printOperand(const SSAValue *value) { printValueID(value); }
enum { nameSentinel = ~0U };
protected:
void numberValueID(const SSAValue *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
unsigned id;
switch (value->getKind()) {
case SSAValueKind::BBArgument:
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
id = nextValueID++;
break;
case SSAValueKind::FnArgument:
id = nextFnArgumentID++;
break;
case SSAValueKind::ForStmt:
id = nextLoopID++;
break;
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
// Give constant integers special names.
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->getAs<ConstantIntOp>()) {
specialName << 'c' << intOp->getValue();
if (!intOp->getType()->isAffineInt())
specialName << '_' << *intOp->getType();
}
}
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case SSAValueKind::BBArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *bb = cast<BBArgument>(value)->getOwner())
if (auto *fn = bb->getFunction())
if (&fn->front() == bb) {
specialName << "arg" << nextArgumentID++;
break;
}
// Otherwise number it normally.
LLVM_FALLTHROUGH;
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::FnArgument:
specialName << "arg" << nextArgumentID++;
break;
case SSAValueKind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
}
}
// Ok, this value had an interesting name. Remember it with a sentinel.
valueIDs[value] = nameSentinel;
// Remember that we've used this name, checking to see if we had a conflict.
auto insertRes = usedNames.insert(specialName.str());
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
// Otherwise, we had a conflict - probe until we find a unique name. This
// is guaranteed to terminate (and usually in a single iteration) because it
// generates new names by incrementing nextConflictID.
while (1) {
std::string probeName =
specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
insertRes = usedNames.insert(probeName);
if (insertRes.second) {
// If this is the first use of the name, then we're successful!
valueNames[value] = insertRes.first->first();
return;
}
}
valueIDs[value] = id;
}
void printValueID(const SSAValue *value, bool printResultNo = true) const {
@ -620,22 +674,37 @@ protected:
}
os << '%';
if (isa<ForStmt>(value))
if (it->second != nameSentinel) {
os << it->second;
} else {
auto nameIt = valueNames.find(lookupValue);
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
os << nameIt->second;
}
os << 'i';
else if (isa<FnArgument>(value))
os << "arg";
os << it->getSecond();
if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
private:
/// This is the value ID for each SSA value in the current function.
/// This is the value ID for each SSA value in the current function. If this
/// returns ~0, then the valueID has an entry in valueNames.
DenseMap<const SSAValue *, unsigned> valueIDs;
DenseMap<const SSAValue *, StringRef> valueNames;
/// This keeps track of all of the non-numeric names that are in flight,
/// allowing us to check for duplicates.
llvm::StringSet<> usedNames;
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
/// This is the ID to assign to the next induction variable.
unsigned nextLoopID = 0;
unsigned nextFnArgumentID = 0;
/// This is the next ID to assign to an MLFunction argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
unsigned nextConflictID = 0;
};
} // end anonymous namespace

View File

@ -209,14 +209,6 @@ void OperationInst::eraseFromBlock() {
getBlock()->getOperations().erase(this);
}
/// If this value is the result of an OperationInst, return the instruction
/// that defines it.
OperationInst *SSAValue::getDefiningInst() {
if (auto *result = dyn_cast<InstResult>(this))
return result->getOwner();
return nullptr;
}
//===----------------------------------------------------------------------===//
// TerminatorInst
//===----------------------------------------------------------------------===//

45
mlir/lib/IR/SSAValue.cpp Normal file
View File

@ -0,0 +1,45 @@
//===- Instructions.cpp - MLIR CFGFunction Instruction Classes ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
/// If this value is the result of an OperationInst, return the instruction
/// that defines it.
OperationInst *SSAValue::getDefiningInst() {
if (auto *result = dyn_cast<InstResult>(this))
return result->getOwner();
return nullptr;
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *SSAValue::getDefiningStmt() {
if (auto *result = dyn_cast<StmtResult>(this))
return result->getOwner();
return nullptr;
}
Operation *SSAValue::getDefiningOperation() {
if (auto *inst = getDefiningInst())
return inst;
if (auto *stmt = getDefiningStmt())
return stmt;
return nullptr;
}

View File

@ -190,6 +190,22 @@ const char *AllocOp::verify() const {
return nullptr;
}
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue() << " : " << *getType();
}
OpAsmParserResult ConstantOp::parse(OpAsmParser *parser) {
Attribute *valueAttr;
Type *type;
if (parser->parseAttribute(valueAttr) || parser->parseColonType(type))
return {};
auto &builder = parser->getBuilder();
return OpAsmParserResult(
/*operands=*/{}, type,
NamedAttribute(builder.getIdentifier("value"), valueAttr));
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {

View File

@ -199,14 +199,6 @@ void OperationStmt::dropAllReferences() {
op.drop();
}
/// If this value is the result of an OperationStmt, return the statement
/// that defines it.
OperationStmt *SSAValue::getDefiningStmt() {
if (auto *result = dyn_cast<StmtResult>(this))
return result->getOwner();
return nullptr;
}
//===----------------------------------------------------------------------===//
// ForStmt
//===----------------------------------------------------------------------===//

View File

@ -11,13 +11,13 @@
// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
cfgfunc @cfgfunc_with_ops(f32) {
bb0(%a : f32):
// CHECK: %1 = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
%t = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: %2 = dim %1, 2 : tensor<4x4x?xf32>
// CHECK: %1 = dim %0, 2 : tensor<4x4x?xf32>
%t2 = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
// CHECK: %3 = addf %0, %0 : f32
// CHECK: %2 = addf %arg0, %arg0 : f32
%x = "addf"(%a, %a) : (f32,f32) -> (f32)
// CHECK: return
@ -26,22 +26,25 @@ bb0(%a : f32):
// CHECK-LABEL: cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
cfgfunc @standard_instrs(tensor<4x4x?xf32>, f32) {
// CHECK: bb0(%0: tensor<4x4x?xf32>, %1: f32):
// CHECK: bb0(%arg0: tensor<4x4x?xf32>, %arg1: f32):
bb42(%t: tensor<4x4x?xf32>, %f: f32):
// CHECK: %2 = dim %0, 2 : tensor<4x4x?xf32>
// CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32>
%a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
// CHECK: %3 = dim %0, 2 : tensor<4x4x?xf32>
// CHECK: %1 = dim %arg0, 2 : tensor<4x4x?xf32>
%a2 = dim %t, 2 : tensor<4x4x?xf32>
// CHECK: %4 = addf %1, %1 : f32
// CHECK: %2 = addf %arg1, %arg1 : f32
%f2 = "addf"(%f, %f) : (f32,f32) -> f32
// CHECK: %5 = addf %4, %4 : f32
// CHECK: %3 = addf %2, %2 : f32
%f3 = addf %f2, %f2 : f32
// CHECK: %6 = "constant"(){value: 42} : () -> i32
// CHECK: %c42_i32 = constant 42 : i32
%x = "constant"(){value: 42} : () -> i32
// CHECK: %c42_i32_0 = constant 42 : i32
%7 = constant 42 : i32
return
}
@ -51,18 +54,18 @@ bb0:
%i = "constant"() {value: 0} : () -> affineint
%j = "constant"() {value: 1} : () -> affineint
// CHECK: affine_apply #map0(%0)
// CHECK: affine_apply #map0(%c0)
%a = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
// CHECK: affine_apply #map1(%0, %1)
// CHECK: affine_apply #map1(%c0, %c1)
%b = "affine_apply" (%i, %j) { map: #map5 } :
(affineint, affineint) -> (affineint, affineint)
// CHECK: affine_apply #map2(%0, %1)[%1, %0]
// CHECK: affine_apply #map2(%c0, %c1)[%c1, %c0]
%c = affine_apply (i,j)[m,n] -> (i+n, j+m)(%i, %j)[%j, %i]
// CHECK: affine_apply #map3()[%0]
// CHECK: affine_apply #map3()[%c0]
%d = affine_apply ()[x] -> (x+1)()[%i]
return
@ -71,10 +74,10 @@ bb0:
// CHECK-LABEL: cfgfunc @load_store
cfgfunc @load_store(memref<4x4xi32>, affineint) {
bb0(%0: memref<4x4xi32>, %1: affineint):
// CHECK: %2 = load %0[%1, %1] : memref<4x4xi32>
// CHECK: %0 = load %arg0[%arg1, %arg1] : memref<4x4xi32>
%2 = "load"(%0, %1, %1) : (memref<4x4xi32>, affineint, affineint)->i32
// CHECK: %3 = load %0[%1, %1] : memref<4x4xi32>
// CHECK: %1 = load %arg0[%arg1, %arg1] : memref<4x4xi32>
%3 = load %0[%1, %1] : memref<4x4xi32>
return

View File

@ -14,15 +14,15 @@ bb0:
%2 = "constant"() {value: 1} : () -> affineint
// Test alloc with dynamic dimensions.
// CHECK: %3 = alloc(%1, %2) : memref<?x?xf32, #map0, 1>
// CHECK: %1 = alloc(%c0, %c1) : memref<?x?xf32, #map0, 1>
%3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
// Test alloc with no dynamic dimensions and one symbol.
// CHECK: %4 = alloc()[%1] : memref<2x4xf32, #map1, 1>
// CHECK: %2 = alloc()[%c0] : memref<2x4xf32, #map1, 1>
%4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
// Test alloc with dynamic dimensions and one symbol.
// CHECK: %5 = alloc(%2)[%1] : memref<2x?xf32, #map1, 1>
// CHECK: %3 = alloc(%c1)[%c0] : memref<2x?xf32, #map1, 1>
%5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> (d0 + s0, d1), 1>
// CHECK: return
@ -35,13 +35,13 @@ bb0:
// CHECK: %0 = alloc() : memref<1024x64xf32, #map0, 1>
%0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
%1 = "constant"() {value: 0} : () -> affineint
%2 = "constant"() {value: 1} : () -> affineint
%1 = constant 0 : affineint
%2 = constant 1 : affineint
// CHECK: %3 = load %0[%1, %2] : memref<1024x64xf32, #map0, 1>
// CHECK: %1 = load %0[%c0, %c1] : memref<1024x64xf32, #map0, 1>
%3 = load %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
// CHECK: store %3, %0[%1, %2] : memref<1024x64xf32, #map0, 1>
// CHECK: store %1, %0[%c0, %c1] : memref<1024x64xf32, #map0, 1>
store %3, %0[%1, %2] : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
return

View File

@ -68,22 +68,22 @@ extfunc @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<i8, #map1, 0>) -> ()
// CHECK-LABEL: cfgfunc @simpleCFG(i32, f32) -> i1 {
cfgfunc @simpleCFG(i32, f32) -> i1 {
// CHECK: bb0(%0: i32, %1: f32):
bb42 (%0: i32, %f: f32):
// CHECK: %2 = "foo"() : () -> i64
// CHECK: bb0(%arg0: i32, %arg1: f32):
bb42 (%arg0: i32, %f: f32):
// CHECK: %0 = "foo"() : () -> i64
%1 = "foo"() : ()->i64
// CHECK: "bar"(%2) : (i64) -> (i1, i1, i1)
// CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
%2 = "bar"(%1) : (i64) -> (i1,i1,i1)
// CHECK: return %3#1
// CHECK: return %1#1
return %2#1 : i1
// CHECK: }
}
// CHECK-LABEL: cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
cfgfunc @simpleCFGUsingBBArgs(i32, i64) {
// CHECK: bb0(%0: i32, %1: i64):
bb42 (%0: i32, %f: i64):
// CHECK: "bar"(%1) : (i64) -> (i1, i1, i1)
// CHECK: bb0(%arg0: i32, %arg1: i64):
bb42 (%arg0: i32, %f: i64):
// CHECK: "bar"(%arg1) : (i64) -> (i1, i1, i1)
%2 = "bar"(%f) : (i64) -> (i1,i1,i1)
// CHECK: return
return
@ -263,3 +263,22 @@ bb2(%x2 : i64, %y2 : i32, %z2 : i32):
%z = "foo"() : () -> i32
return %z : i32
}
// Test pretty printing of constant names.
// CHECK-LABEL: cfgfunc @constants
cfgfunc @constants() -> (i32, i23, i23) {
bb0:
// CHECK: %c42_i32 = constant 42 : i32
%x = constant 42 : i32
// CHECK: %c17_i23 = constant 17 : i23
%y = constant 17 : i23
// This is a redundant definition of 17, the asmprinter gives it a unique name
// CHECK: %c17_i23_0 = constant 17 : i23
%z = constant 17 : i23
// CHECK: return %c42_i32, %c17_i23, %c17_i23_0
return %x, %y, %z : i32, i23, i23
}