Add rewrite pattern to compose maps into affine load/stores

- add a canonicalization pattern to compose maps into affine loads/stores;
  templatize the pattern and reuse it for affine.apply as well
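
  For example (an illustrative sketch; %A, %i, and the map below are
  hypothetical, not taken from the diff), the pattern rewrites

    %idx = affine.apply (d0) -> (d0 + 1) (%i)
    %v = affine.load %A[%idx] : memref<1024xf32>

  into a single load with the composed map:

    %v = affine.load %A[%i + 1] : memref<1024xf32>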

- rename getIndices() -> getMapOperands() (getIndices is confusing since
  these are no longer the indices themselves but operands to the map
  whose results are the indices). This also makes the accessor uniform
  across affine.apply/load/store. Change argument names on the affine
  load/store builders to avoid confusion. Drop an unused, confusing build
  method on AffineStoreOp.
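
  A minimal usage sketch (assuming an AffineLoadOp named loadOp; the same
  accessor now exists on affine.apply and affine.store ops):

    // The map operands feed the map; the map's results are the indices.
    SmallVector<Value *, 4> mapOperands(loadOp.getMapOperands());
    AffineMap map = loadOp.getAffineMap();
    assert(mapOperands.size() == map.getNumInputs());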

- update the incomplete doc comment for canonicalizeMapAndOperands (this
  was missed in a previous update).
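
  Illustrative effect (hypothetical map and operands, not from the diff):
  canonicalization rewrites

    map: (d0, d1) -> (d0 + d1), operands: [%i, %i]

  to

    map: (d0) -> (d0 * 2), operands: [%i]

  and, when %c4 is defined by a constant op yielding index value 4,

    map: (d0) -> (d0 + 1), operands: [%c4]

  to

    map: () -> (5), operands: []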

Addresses issue tensorflow/mlir#121

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#122

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/122 from bondhugula:compose-load-store e71de1771e56a85c4282c10cb43f30cef0701c4f
PiperOrigin-RevId: 269619540
Author: Uday Bondhugula, 2019-09-17 11:49:14 -07:00 (committed by A. Unique TensorFlower)
Commit: bd7de6d4df (parent: 62e1faa6f6)
8 changed files with 133 additions and 54 deletions


@@ -83,6 +83,8 @@ public:
static StringRef getOperationName() { return "affine.apply"; }
operand_range getMapOperands() { return getOperands(); }
// Hooks to customize behavior of this op.
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);
@@ -400,9 +402,12 @@ public:
/// Builds an affine load op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<Value *> operands);
/// Builds an affine load op an identify map and operands.
/// Builds an affine load op with an identity map and operands.
static void build(Builder *builder, OperationState *result, Value *memref,
ArrayRef<Value *> indices = {});
/// Builds an affine load op with the specified map and its operands.
static void build(Builder *builder, OperationState *result, Value *memref,
AffineMap map, ArrayRef<Value *> mapOperands);
/// Returns the operand index of the memref.
unsigned getMemRefOperandIndex() { return 0; }
@@ -415,7 +420,7 @@ public:
}
/// Get affine map operands.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
@@ -462,14 +467,14 @@ class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
public:
using Op::Op;
/// Builds an affine store operation with the specified map and operands.
static void build(Builder *builder, OperationState *result,
Value *valueToStore, AffineMap map,
ArrayRef<Value *> operands);
/// Builds an affine store operation with an identity map and operands.
/// Builds an affine store operation with the provided indices (identity map).
static void build(Builder *builder, OperationState *result,
Value *valueToStore, Value *memref,
ArrayRef<Value *> operands);
ArrayRef<Value *> indices);
/// Builds an affine store operation with the specified map and its operands.
static void build(Builder *builder, OperationState *result,
Value *valueToStore, Value *memref, AffineMap map,
ArrayRef<Value *> mapOperands);
/// Get value to be stored by store operation.
Value *getValueToStore() { return getOperand(0); }
@@ -486,7 +491,7 @@ public:
}
/// Get affine map operands.
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
/// Returns the affine map used to index the memref for this operation.
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
@@ -521,6 +526,9 @@ bool isValidSymbol(Value *value);
/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
/// 3. promote valid symbols to symbolic operands in case they appeared as
/// dimensional operands
/// 4. propagate constant operands and drop them
void canonicalizeMapAndOperands(AffineMap *map,
llvm::SmallVectorImpl<Value *> *operands);
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does


@@ -236,7 +236,7 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
int uniqueVaryingIndexAlongIv = -1;
auto accessMap = memoryOp.getAffineMap();
SmallVector<Value *, 4> mapOperands(memoryOp.getIndices());
SmallVector<Value *, 4> mapOperands(memoryOp.getMapOperands());
unsigned numDims = accessMap.getNumDims();
for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
// Gather map operands used in result expr 'i' into 'exprOperands'.


@@ -847,7 +847,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp.getMemRefType();
indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp.getIndices()) {
for (auto *index : loadOp.getMapOperands()) {
indices.push_back(index);
}
} else {
@@ -857,7 +857,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
memref = storeOp.getMemRef();
auto storeMemrefType = storeOp.getMemRefType();
indices.reserve(storeMemrefType.getRank());
for (auto *index : storeOp.getIndices()) {
for (auto *index : storeOp.getMapOperands()) {
indices.push_back(index);
}
}


@@ -698,30 +698,63 @@ void mlir::canonicalizeSetAndOperands(
}
namespace {
/// Simplify AffineApply operations.
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
using OpRewritePattern<AffineOpTy>::OpRewritePattern;
PatternMatchResult matchAndRewrite(AffineApplyOp apply,
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
AffineMap map, ArrayRef<Value *> mapOperands) const;
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
auto map = apply.getAffineMap();
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
std::is_same<AffineOpTy, AffineApplyOp>::value,
"affine load/store/apply op expected");
auto map = affineOp.getAffineMap();
AffineMap oldMap = map;
SmallVector<Value *, 8> resultOperands(apply.getOperands());
auto oldOperands = affineOp.getMapOperands();
SmallVector<Value *, 8> resultOperands(oldOperands);
composeAffineMapAndOperands(&map, &resultOperands);
if (map == oldMap)
return matchFailure();
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
return this->matchFailure();
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
return matchSuccess();
replaceAffineOp(rewriter, affineOp, map, resultOperands);
return this->matchSuccess();
}
};
// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
ArrayRef<Value *> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
mapOperands);
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
ArrayRef<Value *> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
ArrayRef<Value *> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
}
} // end anonymous namespace.
void AffineApplyOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyAffineApply>(context);
results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1689,6 +1722,7 @@ void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void AffineLoadOp::build(Builder *builder, OperationState *result,
AffineMap map, ArrayRef<Value *> operands) {
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
result->addOperands(operands);
if (map)
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
@@ -1697,17 +1731,25 @@ void AffineLoadOp::build(Builder *builder, OperationState *result,
}
void AffineLoadOp::build(Builder *builder, OperationState *result,
Value *memref, ArrayRef<Value *> indices) {
Value *memref, AffineMap map,
ArrayRef<Value *> mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result->addOperands(memref);
result->addOperands(indices);
result->addOperands(mapOperands);
auto memrefType = memref->getType().cast<MemRefType>();
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
result->types.push_back(memrefType.getElementType());
}
void AffineLoadOp::build(Builder *builder, OperationState *result,
Value *memref, ArrayRef<Value *> indices) {
auto memrefType = memref->getType().cast<MemRefType>();
auto rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map = rank ? builder->getMultiDimIdentityMap(rank)
: builder->getEmptyAffineMap();
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
result->types.push_back(memrefType.getElementType());
build(builder, result, memref, map, indices);
}
ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -1733,7 +1775,7 @@ void AffineLoadOp::print(OpAsmPrinter *p) {
*p << "affine.load " << *getMemRef() << '[';
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
if (mapAttr) {
SmallVector<Value *, 2> operands(getIndices());
SmallVector<Value *, 2> operands(getMapOperands());
p->printAffineMapOfSSAIds(mapAttr, operands);
}
*p << ']';
@@ -1759,7 +1801,7 @@ LogicalResult AffineLoadOp::verify() {
"expects the number of subscripts to be equal to memref rank");
}
for (auto *idx : getIndices()) {
for (auto *idx : getMapOperands()) {
if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -1772,6 +1814,7 @@ void AffineLoadOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// load(memrefcast) -> load
results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1779,27 +1822,26 @@ void AffineLoadOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
void AffineStoreOp::build(Builder *builder, OperationState *result,
Value *valueToStore, AffineMap map,
ArrayRef<Value *> operands) {
result->addOperands(valueToStore);
result->addOperands(operands);
if (map)
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
}
void AffineStoreOp::build(Builder *builder, OperationState *result,
Value *valueToStore, Value *memref,
ArrayRef<Value *> operands) {
Value *valueToStore, Value *memref, AffineMap map,
ArrayRef<Value *> mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result->addOperands(valueToStore);
result->addOperands(memref);
result->addOperands(operands);
result->addOperands(mapOperands);
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
}
// Use identity map.
void AffineStoreOp::build(Builder *builder, OperationState *result,
Value *valueToStore, Value *memref,
ArrayRef<Value *> indices) {
auto memrefType = memref->getType().cast<MemRefType>();
auto rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map = rank ? builder->getMultiDimIdentityMap(rank)
: builder->getEmptyAffineMap();
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
build(builder, result, valueToStore, memref, map, indices);
}
ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
@@ -1828,7 +1870,7 @@ void AffineStoreOp::print(OpAsmPrinter *p) {
*p << ", " << *getMemRef() << '[';
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
if (mapAttr) {
SmallVector<Value *, 2> operands(getIndices());
SmallVector<Value *, 2> operands(getMapOperands());
p->printAffineMapOfSSAIds(mapAttr, operands);
}
*p << ']';
@@ -1855,7 +1897,7 @@ LogicalResult AffineStoreOp::verify() {
"expects the number of subscripts to be equal to memref rank");
}
for (auto *idx : getIndices()) {
for (auto *idx : getMapOperands()) {
if (!idx->getType().isIndex())
return emitOpError("index to store must have 'index' type");
if (!isValidAffineIndexOperand(idx))
@@ -1868,6 +1910,7 @@ void AffineStoreOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
/// store(memrefcast) -> store
results.insert<MemRefCastFolder>(getOperationName(), context);
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
#define GET_OP_CLASSES


@@ -403,7 +403,7 @@ public:
virtual PatternMatchResult
matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override {
// Expand affine map from 'affineLoadOp'.
SmallVector<Value *, 8> indices(op.getIndices());
SmallVector<Value *, 8> indices(op.getMapOperands());
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!maybeExpandedMap)
@@ -425,7 +425,7 @@ public:
virtual PatternMatchResult
matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override {
// Expand affine map from 'affineStoreOp'.
SmallVector<Value *, 8> indices(op.getIndices());
SmallVector<Value *, 8> indices(op.getMapOperands());
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
if (!maybeExpandedMap)


@@ -814,14 +814,15 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
// as needed by various targets.
if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
OpBuilder b(opInst);
SmallVector<Value *, 4> mapOperands(load.getIndices());
SmallVector<Value *, 4> mapOperands(load.getMapOperands());
SmallVector<Value *, 8> indices;
indices.reserve(load.getMemRefType().getRank());
if (load.getAffineMap() !=
b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
} else {
indices.append(load.getIndices().begin(), load.getIndices().end());
indices.append(load.getMapOperands().begin(),
load.getMapOperands().end());
}
auto permutationMap =
makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
@@ -1038,7 +1039,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
auto *value = store.getValueToStore();
auto *vectorValue = vectorizeOperand(value, opInst, state);
SmallVector<Value *, 4> mapOperands(store.getIndices());
SmallVector<Value *, 4> mapOperands(store.getMapOperands());
SmallVector<Value *, 8> indices;
indices.reserve(store.getMemRefType().getRank());
if (store.getAffineMap() !=
@@ -1046,7 +1047,8 @@ static Operation *vectorizeOneOperation(Operation *opInst,
computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
indices);
} else {
indices.append(store.getIndices().begin(), store.getIndices().end());
indices.append(store.getMapOperands().begin(),
store.getMapOperands().end());
}
auto permutationMap =


@@ -424,6 +424,7 @@ func @fold_empty_loop() {
}
return
}
// CHECK: return
// -----
@@ -476,3 +477,29 @@ func @canonicalize_bounds(%M : index, %N : index) {
}
return
}
// -----
// Compose maps into affine load and store ops.
// CHECK-DAG: #map{{[0-9]+}} = (d0) -> (d0 + 1)
// CHECK-LABEL: @compose_into_affine_load_store
func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
%cf1 = constant 1.0 : f32
// CHECK: affine.for %[[IV:.*]] = 0 to 1024
affine.for %i = 0 to 1024 {
// Make sure the unused operand (%u below) gets dropped as well.
%idx = affine.apply (d0, d1) -> (d0 + 1) (%i, %u)
affine.load %A[%idx] : memref<1024xf32>
affine.store %cf1, %A[%idx] : memref<1024xf32>
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]] + 1]
// CHECK-NEXT: affine.store %cst, %{{.*}}[%[[IV]] + 1]
// Map remains the same, but operand changes on composition.
%copy = affine.apply (d0) -> (d0) (%i)
affine.load %A[%copy] : memref<1024xf32>
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]]]
}
return
}


@@ -256,13 +256,12 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK-LABEL: func @memref_cast_folding
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
// CHECK-NOT: memref_cast
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
// CHECK-NEXT: %c0 = constant 0 : index
%c0 = constant 0 : index
// CHECK-NOT: dim
%dim = dim %1, 0 : memref<? x f32>
// CHECK: affine.load %arg0[%c4 - 1]
// CHECK-NEXT: affine.load %arg0[3]
affine.load %1[%dim - 1] : memref<?xf32>
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>