Support printing SSA ids in affine.load/store which do not have special names.

PiperOrigin-RevId: 254997746
This commit is contained in:
Andy Davis 2019-06-25 10:29:53 -07:00 committed by A. Unique TensorFlower
parent 66ed7d6d83
commit 91f27d025b
3 changed files with 61 additions and 44 deletions

View File

@ -87,6 +87,8 @@ public:
/// Prints an affine map of SSA ids, where SSA id names are used in place
/// of dims/symbols.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
ArrayRef<Value *> operands) = 0;
@ -380,6 +382,8 @@ public:
}
/// Parses an affine map attribute where dims and symbols are SSA operands.
/// Operand values must come from single-result sources, and be valid
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
virtual ParseResult
parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
StringRef attrName,

View File

@ -347,8 +347,9 @@ public:
void printLocation(LocationAttr loc);
void printAffineMap(AffineMap map);
void printAffineExpr(AffineExpr expr, ArrayRef<StringRef> dimValueNames = {},
ArrayRef<StringRef> symbolValueNames = {});
void printAffineExpr(
AffineExpr expr,
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
void printAffineConstraint(AffineExpr expr, bool isEq);
void printIntegerSet(IntegerSet set);
@ -370,10 +371,9 @@ protected:
Weak, // + and -
Strong, // All other binary operators.
};
void printAffineExprInternal(AffineExpr expr,
BindingStrength enclosingTightness,
ArrayRef<StringRef> dimValueNames = {},
ArrayRef<StringRef> symbolValueNames = {});
void printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
};
} // end anonymous namespace
@ -921,30 +921,28 @@ void ModulePrinter::printType(Type type) {
// Affine expressions and maps
//===----------------------------------------------------------------------===//
void ModulePrinter::printAffineExpr(AffineExpr expr,
ArrayRef<StringRef> dimValueNames,
ArrayRef<StringRef> symbolValueNames) {
printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
void ModulePrinter::printAffineExpr(
AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
void ModulePrinter::printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
ArrayRef<StringRef> dimValueNames, ArrayRef<StringRef> symbolValueNames) {
llvm::function_ref<void(unsigned, bool)> printValueName) {
const char *binopSpelling = nullptr;
switch (expr.getKind()) {
case AffineExprKind::SymbolId: {
unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
if (pos < symbolValueNames.size())
os << "symbol(%" << symbolValueNames[pos] << ')';
if (printValueName)
printValueName(pos, /*isSymbol=*/true);
else
os << 's' << pos;
return;
}
case AffineExprKind::DimId: {
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
if (pos < dimValueNames.size())
os << '%' << dimValueNames[pos];
if (printValueName)
printValueName(pos, /*isSymbol=*/false);
else
os << 'd' << pos;
return;
@ -982,16 +980,14 @@ void ModulePrinter::printAffineExprInternal(
auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
if (rhsConst && rhsConst.getValue() == -1) {
os << "-";
printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
return;
}
printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
os << binopSpelling;
printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames,
symbolValueNames);
printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
@ -1009,15 +1005,15 @@ void ModulePrinter::printAffineExprInternal(
AffineExpr rrhsExpr = rhs.getRHS();
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
if (rrhs.getValue() == -1) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
printValueName);
os << " - ";
if (rhs.getLHS().getKind() == AffineExprKind::Add) {
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
dimValueNames, symbolValueNames);
printValueName);
} else {
printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
dimValueNames, symbolValueNames);
printValueName);
}
if (enclosingTightness == BindingStrength::Strong)
@ -1026,11 +1022,11 @@ void ModulePrinter::printAffineExprInternal(
}
if (rrhs.getValue() < -1) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
printValueName);
os << " - ";
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
dimValueNames, symbolValueNames);
printValueName);
os << " * " << -rrhs.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
@ -1043,8 +1039,7 @@ void ModulePrinter::printAffineExprInternal(
// Pretty print addition to a negative number as a subtraction.
if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
if (rhsConst.getValue() < 0) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
os << " - " << -rhsConst.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
@ -1052,11 +1047,10 @@ void ModulePrinter::printAffineExprInternal(
}
}
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
os << " + ";
printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames,
symbolValueNames);
printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
@ -1242,16 +1236,18 @@ public:
ArrayRef<Value *> operands) {
AffineMap map = mapAttr.getValue();
unsigned numDims = map.getNumDims();
SmallVector<StringRef, 2> dimValueNames;
SmallVector<StringRef, 1> symbolValueNames;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (i < numDims)
dimValueNames.push_back(valueNames[operands[i]]);
else
symbolValueNames.push_back(valueNames[operands[i]]);
}
auto printValueName = [&](unsigned pos, bool isSymbol) {
unsigned index = isSymbol ? numDims + pos : pos;
assert(index < operands.size());
if (isSymbol)
os << "symbol(";
printValueID(operands[index]);
if (isSymbol)
os << ')';
};
interleaveComma(map.getResults(), [&](AffineExpr expr) {
printAffineExpr(expr, dimValueNames, symbolValueNames);
printAffineExpr(expr, printValueName);
});
}

View File

@ -165,4 +165,21 @@ func @test6(%arg0 : index, %arg1 : index) {
}
}
return
}
// -----
// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 1)
// Test with operands without special SSA name.
func @test7() {
%0 = alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = affine.apply (d1) -> (d1 + 1)(%i0)
%2 = affine.load %0[%1] : memref<10xf32>
affine.store %2, %0[%1] : memref<10xf32>
// CHECK: affine.load %0[%1] : memref<10xf32>
// CHECK: affine.store %2, %0[%1] : memref<10xf32>
}
return
}