forked from OSchip/llvm-project
Support printing SSA ids in affine.load/store which do not have special names.
PiperOrigin-RevId: 254997746
This commit is contained in:
parent
66ed7d6d83
commit
91f27d025b
|
@ -87,6 +87,8 @@ public:
|
|||
|
||||
/// Prints an affine map of SSA ids, where SSA id names are used in place
|
||||
/// of dims/symbols.
|
||||
/// Operand values must come from single-result sources, and be valid
|
||||
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
|
||||
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
|
||||
ArrayRef<Value *> operands) = 0;
|
||||
|
||||
|
@ -380,6 +382,8 @@ public:
|
|||
}
|
||||
|
||||
/// Parses an affine map attribute where dims and symbols are SSA operands.
|
||||
/// Operand values must come from single-result sources, and be valid
|
||||
/// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
|
||||
virtual ParseResult
|
||||
parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
|
||||
StringRef attrName,
|
||||
|
|
|
@ -347,8 +347,9 @@ public:
|
|||
void printLocation(LocationAttr loc);
|
||||
|
||||
void printAffineMap(AffineMap map);
|
||||
void printAffineExpr(AffineExpr expr, ArrayRef<StringRef> dimValueNames = {},
|
||||
ArrayRef<StringRef> symbolValueNames = {});
|
||||
void printAffineExpr(
|
||||
AffineExpr expr,
|
||||
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
|
||||
void printAffineConstraint(AffineExpr expr, bool isEq);
|
||||
void printIntegerSet(IntegerSet set);
|
||||
|
||||
|
@ -370,10 +371,9 @@ protected:
|
|||
Weak, // + and -
|
||||
Strong, // All other binary operators.
|
||||
};
|
||||
void printAffineExprInternal(AffineExpr expr,
|
||||
BindingStrength enclosingTightness,
|
||||
ArrayRef<StringRef> dimValueNames = {},
|
||||
ArrayRef<StringRef> symbolValueNames = {});
|
||||
void printAffineExprInternal(
|
||||
AffineExpr expr, BindingStrength enclosingTightness,
|
||||
llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -921,30 +921,28 @@ void ModulePrinter::printType(Type type) {
|
|||
// Affine expressions and maps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ModulePrinter::printAffineExpr(AffineExpr expr,
|
||||
ArrayRef<StringRef> dimValueNames,
|
||||
ArrayRef<StringRef> symbolValueNames) {
|
||||
printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
void ModulePrinter::printAffineExpr(
|
||||
AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
|
||||
printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
|
||||
}
|
||||
|
||||
void ModulePrinter::printAffineExprInternal(
|
||||
AffineExpr expr, BindingStrength enclosingTightness,
|
||||
ArrayRef<StringRef> dimValueNames, ArrayRef<StringRef> symbolValueNames) {
|
||||
llvm::function_ref<void(unsigned, bool)> printValueName) {
|
||||
const char *binopSpelling = nullptr;
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::SymbolId: {
|
||||
unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
|
||||
if (pos < symbolValueNames.size())
|
||||
os << "symbol(%" << symbolValueNames[pos] << ')';
|
||||
if (printValueName)
|
||||
printValueName(pos, /*isSymbol=*/true);
|
||||
else
|
||||
os << 's' << pos;
|
||||
return;
|
||||
}
|
||||
case AffineExprKind::DimId: {
|
||||
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
|
||||
if (pos < dimValueNames.size())
|
||||
os << '%' << dimValueNames[pos];
|
||||
if (printValueName)
|
||||
printValueName(pos, /*isSymbol=*/false);
|
||||
else
|
||||
os << 'd' << pos;
|
||||
return;
|
||||
|
@ -982,16 +980,14 @@ void ModulePrinter::printAffineExprInternal(
|
|||
auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
|
||||
if (rhsConst && rhsConst.getValue() == -1) {
|
||||
os << "-";
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
|
||||
return;
|
||||
}
|
||||
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
|
||||
|
||||
os << binopSpelling;
|
||||
printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
|
||||
|
||||
if (enclosingTightness == BindingStrength::Strong)
|
||||
os << ')';
|
||||
|
@ -1009,15 +1005,15 @@ void ModulePrinter::printAffineExprInternal(
|
|||
AffineExpr rrhsExpr = rhs.getRHS();
|
||||
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
|
||||
if (rrhs.getValue() == -1) {
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
|
||||
printValueName);
|
||||
os << " - ";
|
||||
if (rhs.getLHS().getKind() == AffineExprKind::Add) {
|
||||
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
|
||||
dimValueNames, symbolValueNames);
|
||||
printValueName);
|
||||
} else {
|
||||
printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
|
||||
dimValueNames, symbolValueNames);
|
||||
printValueName);
|
||||
}
|
||||
|
||||
if (enclosingTightness == BindingStrength::Strong)
|
||||
|
@ -1026,11 +1022,11 @@ void ModulePrinter::printAffineExprInternal(
|
|||
}
|
||||
|
||||
if (rrhs.getValue() < -1) {
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
|
||||
printValueName);
|
||||
os << " - ";
|
||||
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
|
||||
dimValueNames, symbolValueNames);
|
||||
printValueName);
|
||||
os << " * " << -rrhs.getValue();
|
||||
if (enclosingTightness == BindingStrength::Strong)
|
||||
os << ')';
|
||||
|
@ -1043,8 +1039,7 @@ void ModulePrinter::printAffineExprInternal(
|
|||
// Pretty print addition to a negative number as a subtraction.
|
||||
if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
|
||||
if (rhsConst.getValue() < 0) {
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
|
||||
os << " - " << -rhsConst.getValue();
|
||||
if (enclosingTightness == BindingStrength::Strong)
|
||||
os << ')';
|
||||
|
@ -1052,11 +1047,10 @@ void ModulePrinter::printAffineExprInternal(
|
|||
}
|
||||
}
|
||||
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
|
||||
|
||||
os << " + ";
|
||||
printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames,
|
||||
symbolValueNames);
|
||||
printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
|
||||
|
||||
if (enclosingTightness == BindingStrength::Strong)
|
||||
os << ')';
|
||||
|
@ -1242,16 +1236,18 @@ public:
|
|||
ArrayRef<Value *> operands) {
|
||||
AffineMap map = mapAttr.getValue();
|
||||
unsigned numDims = map.getNumDims();
|
||||
SmallVector<StringRef, 2> dimValueNames;
|
||||
SmallVector<StringRef, 1> symbolValueNames;
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
if (i < numDims)
|
||||
dimValueNames.push_back(valueNames[operands[i]]);
|
||||
else
|
||||
symbolValueNames.push_back(valueNames[operands[i]]);
|
||||
}
|
||||
auto printValueName = [&](unsigned pos, bool isSymbol) {
|
||||
unsigned index = isSymbol ? numDims + pos : pos;
|
||||
assert(index < operands.size());
|
||||
if (isSymbol)
|
||||
os << "symbol(";
|
||||
printValueID(operands[index]);
|
||||
if (isSymbol)
|
||||
os << ')';
|
||||
};
|
||||
|
||||
interleaveComma(map.getResults(), [&](AffineExpr expr) {
|
||||
printAffineExpr(expr, dimValueNames, symbolValueNames);
|
||||
printAffineExpr(expr, printValueName);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -165,4 +165,21 @@ func @test6(%arg0 : index, %arg1 : index) {
|
|||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = (d0) -> (d0 + 1)
|
||||
|
||||
// Test with operands without special SSA name.
|
||||
func @test7() {
|
||||
%0 = alloc() : memref<10xf32>
|
||||
affine.for %i0 = 0 to 10 {
|
||||
%1 = affine.apply (d1) -> (d1 + 1)(%i0)
|
||||
%2 = affine.load %0[%1] : memref<10xf32>
|
||||
affine.store %2, %0[%1] : memref<10xf32>
|
||||
// CHECK: affine.load %0[%1] : memref<10xf32>
|
||||
// CHECK: affine.store %2, %0[%1] : memref<10xf32>
|
||||
}
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue