forked from OSchip/llvm-project
Add rewrite pattern to compose maps into affine load/stores
- add canonicalization pattern to compose maps into affine loads/stores; templatize the pattern and reuse it for affine.apply as well - rename getIndices -> getMapOperands() (getIndices is confusing since these are no longer the indices themselves but operands to the map whose results are the indices). This also makes the accessor uniform across affine.apply/load/store. Change arg names on the affine load/store builder to avoid confusion. Drop an unused confusing build method on AffineStoreOp. - update incomplete doc comment for canonicalizeMapAndOperands (this was missed from a previous update). Addresses issue tensorflow/mlir#121 Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#122 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/122 from bondhugula:compose-load-store e71de1771e56a85c4282c10cb43f30cef0701c4f PiperOrigin-RevId: 269619540
This commit is contained in:
parent
62e1faa6f6
commit
bd7de6d4df
|
@ -83,6 +83,8 @@ public:
|
|||
|
||||
static StringRef getOperationName() { return "affine.apply"; }
|
||||
|
||||
operand_range getMapOperands() { return getOperands(); }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static ParseResult parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p);
|
||||
|
@ -400,9 +402,12 @@ public:
|
|||
/// Builds an affine load op with the specified map and operands.
|
||||
static void build(Builder *builder, OperationState *result, AffineMap map,
|
||||
ArrayRef<Value *> operands);
|
||||
/// Builds an affine load op an identify map and operands.
|
||||
/// Builds an affine load op with an identity map and operands.
|
||||
static void build(Builder *builder, OperationState *result, Value *memref,
|
||||
ArrayRef<Value *> indices = {});
|
||||
/// Builds an affine load op with the specified map and its operands.
|
||||
static void build(Builder *builder, OperationState *result, Value *memref,
|
||||
AffineMap map, ArrayRef<Value *> mapOperands);
|
||||
|
||||
/// Returns the operand index of the memref.
|
||||
unsigned getMemRefOperandIndex() { return 0; }
|
||||
|
@ -415,7 +420,7 @@ public:
|
|||
}
|
||||
|
||||
/// Get affine map operands.
|
||||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
|
||||
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 1); }
|
||||
|
||||
/// Returns the affine map used to index the memref for this operation.
|
||||
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||
|
@ -462,14 +467,14 @@ class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
|
|||
public:
|
||||
using Op::Op;
|
||||
|
||||
/// Builds an affine store operation with the specified map and operands.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, AffineMap map,
|
||||
ArrayRef<Value *> operands);
|
||||
/// Builds an affine store operation with an identity map and operands.
|
||||
/// Builds an affine store operation with the provided indices (identity map).
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, Value *memref,
|
||||
ArrayRef<Value *> operands);
|
||||
ArrayRef<Value *> indices);
|
||||
/// Builds an affine store operation with the specified map and its operands.
|
||||
static void build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, Value *memref, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands);
|
||||
|
||||
/// Get value to be stored by store operation.
|
||||
Value *getValueToStore() { return getOperand(0); }
|
||||
|
@ -486,7 +491,7 @@ public:
|
|||
}
|
||||
|
||||
/// Get affine map operands.
|
||||
operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
|
||||
operand_range getMapOperands() { return llvm::drop_begin(getOperands(), 2); }
|
||||
|
||||
/// Returns the affine map used to index the memref for this operation.
|
||||
AffineMap getAffineMap() { return getAffineMapAttr().getValue(); }
|
||||
|
@ -521,6 +526,9 @@ bool isValidSymbol(Value *value);
|
|||
/// Modifies both `map` and `operands` in-place so as to:
|
||||
/// 1. drop duplicate operands
|
||||
/// 2. drop unused dims and symbols from map
|
||||
/// 3. promote valid symbols to symbolic operands in case they appeared as
|
||||
/// dimensional operands
|
||||
/// 4. propagate constant operands and drop them
|
||||
void canonicalizeMapAndOperands(AffineMap *map,
|
||||
llvm::SmallVectorImpl<Value *> *operands);
|
||||
/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
|
||||
|
|
|
@ -236,7 +236,7 @@ static bool isContiguousAccess(Value *iv, LoadOrStoreOp memoryOp,
|
|||
|
||||
int uniqueVaryingIndexAlongIv = -1;
|
||||
auto accessMap = memoryOp.getAffineMap();
|
||||
SmallVector<Value *, 4> mapOperands(memoryOp.getIndices());
|
||||
SmallVector<Value *, 4> mapOperands(memoryOp.getMapOperands());
|
||||
unsigned numDims = accessMap.getNumDims();
|
||||
for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
|
||||
// Gather map operands used result expr 'i' in 'exprOperands'.
|
||||
|
|
|
@ -847,7 +847,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
|
|||
opInst = loadOrStoreOpInst;
|
||||
auto loadMemrefType = loadOp.getMemRefType();
|
||||
indices.reserve(loadMemrefType.getRank());
|
||||
for (auto *index : loadOp.getIndices()) {
|
||||
for (auto *index : loadOp.getMapOperands()) {
|
||||
indices.push_back(index);
|
||||
}
|
||||
} else {
|
||||
|
@ -857,7 +857,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
|
|||
memref = storeOp.getMemRef();
|
||||
auto storeMemrefType = storeOp.getMemRefType();
|
||||
indices.reserve(storeMemrefType.getRank());
|
||||
for (auto *index : storeOp.getIndices()) {
|
||||
for (auto *index : storeOp.getMapOperands()) {
|
||||
indices.push_back(index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -698,30 +698,63 @@ void mlir::canonicalizeSetAndOperands(
|
|||
}
|
||||
|
||||
namespace {
|
||||
/// Simplify AffineApply operations.
|
||||
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
|
||||
/// maps that supply results into them.
|
||||
///
|
||||
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
|
||||
using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
|
||||
template <typename AffineOpTy>
|
||||
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
|
||||
using OpRewritePattern<AffineOpTy>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineApplyOp apply,
|
||||
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
|
||||
AffineMap map, ArrayRef<Value *> mapOperands) const;
|
||||
|
||||
PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto map = apply.getAffineMap();
|
||||
|
||||
static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
|
||||
std::is_same<AffineOpTy, AffineStoreOp>::value ||
|
||||
std::is_same<AffineOpTy, AffineApplyOp>::value,
|
||||
"affine load/store/apply op expected");
|
||||
auto map = affineOp.getAffineMap();
|
||||
AffineMap oldMap = map;
|
||||
SmallVector<Value *, 8> resultOperands(apply.getOperands());
|
||||
auto oldOperands = affineOp.getMapOperands();
|
||||
SmallVector<Value *, 8> resultOperands(oldOperands);
|
||||
composeAffineMapAndOperands(&map, &resultOperands);
|
||||
if (map == oldMap)
|
||||
return matchFailure();
|
||||
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
|
||||
resultOperands.begin()))
|
||||
return this->matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
|
||||
return matchSuccess();
|
||||
replaceAffineOp(rewriter, affineOp, map, resultOperands);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Specialize the template to account for the different build signatures for
|
||||
// affine load, store, and apply ops.
|
||||
template <>
|
||||
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
|
||||
PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands) const {
|
||||
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
|
||||
mapOperands);
|
||||
}
|
||||
template <>
|
||||
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
|
||||
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands) const {
|
||||
rewriter.replaceOpWithNewOp<AffineStoreOp>(
|
||||
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
|
||||
}
|
||||
template <>
|
||||
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
|
||||
PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands) const {
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
|
||||
}
|
||||
} // end anonymous namespace.
|
||||
|
||||
void AffineApplyOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<SimplifyAffineApply>(context);
|
||||
results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1689,6 +1722,7 @@ void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
|
||||
void AffineLoadOp::build(Builder *builder, OperationState *result,
|
||||
AffineMap map, ArrayRef<Value *> operands) {
|
||||
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
|
||||
result->addOperands(operands);
|
||||
if (map)
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
|
@ -1697,17 +1731,25 @@ void AffineLoadOp::build(Builder *builder, OperationState *result,
|
|||
}
|
||||
|
||||
void AffineLoadOp::build(Builder *builder, OperationState *result,
|
||||
Value *memref, ArrayRef<Value *> indices) {
|
||||
Value *memref, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands) {
|
||||
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
|
||||
result->addOperands(memref);
|
||||
result->addOperands(indices);
|
||||
result->addOperands(mapOperands);
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
result->types.push_back(memrefType.getElementType());
|
||||
}
|
||||
|
||||
void AffineLoadOp::build(Builder *builder, OperationState *result,
|
||||
Value *memref, ArrayRef<Value *> indices) {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto rank = memrefType.getRank();
|
||||
// Create identity map for memrefs with at least one dimension or () -> ()
|
||||
// for zero-dimensional memrefs.
|
||||
auto map = rank ? builder->getMultiDimIdentityMap(rank)
|
||||
: builder->getEmptyAffineMap();
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
result->types.push_back(memrefType.getElementType());
|
||||
build(builder, result, memref, map, indices);
|
||||
}
|
||||
|
||||
ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
@ -1733,7 +1775,7 @@ void AffineLoadOp::print(OpAsmPrinter *p) {
|
|||
*p << "affine.load " << *getMemRef() << '[';
|
||||
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
|
||||
if (mapAttr) {
|
||||
SmallVector<Value *, 2> operands(getIndices());
|
||||
SmallVector<Value *, 2> operands(getMapOperands());
|
||||
p->printAffineMapOfSSAIds(mapAttr, operands);
|
||||
}
|
||||
*p << ']';
|
||||
|
@ -1759,7 +1801,7 @@ LogicalResult AffineLoadOp::verify() {
|
|||
"expects the number of subscripts to be equal to memref rank");
|
||||
}
|
||||
|
||||
for (auto *idx : getIndices()) {
|
||||
for (auto *idx : getMapOperands()) {
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to load must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
|
@ -1772,6 +1814,7 @@ void AffineLoadOp::getCanonicalizationPatterns(
|
|||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// load(memrefcast) -> load
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1779,27 +1822,26 @@ void AffineLoadOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AffineStoreOp::build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, AffineMap map,
|
||||
ArrayRef<Value *> operands) {
|
||||
result->addOperands(valueToStore);
|
||||
result->addOperands(operands);
|
||||
if (map)
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
}
|
||||
|
||||
void AffineStoreOp::build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, Value *memref,
|
||||
ArrayRef<Value *> operands) {
|
||||
Value *valueToStore, Value *memref, AffineMap map,
|
||||
ArrayRef<Value *> mapOperands) {
|
||||
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
|
||||
result->addOperands(valueToStore);
|
||||
result->addOperands(memref);
|
||||
result->addOperands(operands);
|
||||
result->addOperands(mapOperands);
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
}
|
||||
|
||||
// Use identity map.
|
||||
void AffineStoreOp::build(Builder *builder, OperationState *result,
|
||||
Value *valueToStore, Value *memref,
|
||||
ArrayRef<Value *> indices) {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto rank = memrefType.getRank();
|
||||
// Create identity map for memrefs with at least one dimension or () -> ()
|
||||
// for zero-dimensional memrefs.
|
||||
auto map = rank ? builder->getMultiDimIdentityMap(rank)
|
||||
: builder->getEmptyAffineMap();
|
||||
result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
|
||||
build(builder, result, valueToStore, memref, map, indices);
|
||||
}
|
||||
|
||||
ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
@ -1828,7 +1870,7 @@ void AffineStoreOp::print(OpAsmPrinter *p) {
|
|||
*p << ", " << *getMemRef() << '[';
|
||||
AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
|
||||
if (mapAttr) {
|
||||
SmallVector<Value *, 2> operands(getIndices());
|
||||
SmallVector<Value *, 2> operands(getMapOperands());
|
||||
p->printAffineMapOfSSAIds(mapAttr, operands);
|
||||
}
|
||||
*p << ']';
|
||||
|
@ -1855,7 +1897,7 @@ LogicalResult AffineStoreOp::verify() {
|
|||
"expects the number of subscripts to be equal to memref rank");
|
||||
}
|
||||
|
||||
for (auto *idx : getIndices()) {
|
||||
for (auto *idx : getMapOperands()) {
|
||||
if (!idx->getType().isIndex())
|
||||
return emitOpError("index to store must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
|
@ -1868,6 +1910,7 @@ void AffineStoreOp::getCanonicalizationPatterns(
|
|||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
/// load(memrefcast) -> load
|
||||
results.insert<MemRefCastFolder>(getOperationName(), context);
|
||||
results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -403,7 +403,7 @@ public:
|
|||
virtual PatternMatchResult
|
||||
matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override {
|
||||
// Expand affine map from 'affineLoadOp'.
|
||||
SmallVector<Value *, 8> indices(op.getIndices());
|
||||
SmallVector<Value *, 8> indices(op.getMapOperands());
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!maybeExpandedMap)
|
||||
|
@ -425,7 +425,7 @@ public:
|
|||
virtual PatternMatchResult
|
||||
matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override {
|
||||
// Expand affine map from 'affineStoreOp'.
|
||||
SmallVector<Value *, 8> indices(op.getIndices());
|
||||
SmallVector<Value *, 8> indices(op.getMapOperands());
|
||||
auto maybeExpandedMap =
|
||||
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
|
||||
if (!maybeExpandedMap)
|
||||
|
|
|
@ -814,14 +814,15 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
|
|||
// as needed by various targets.
|
||||
if (auto load = dyn_cast<AffineLoadOp>(opInst)) {
|
||||
OpBuilder b(opInst);
|
||||
SmallVector<Value *, 4> mapOperands(load.getIndices());
|
||||
SmallVector<Value *, 4> mapOperands(load.getMapOperands());
|
||||
SmallVector<Value *, 8> indices;
|
||||
indices.reserve(load.getMemRefType().getRank());
|
||||
if (load.getAffineMap() !=
|
||||
b.getMultiDimIdentityMap(load.getMemRefType().getRank())) {
|
||||
computeMemoryOpIndices(opInst, load.getAffineMap(), mapOperands, indices);
|
||||
} else {
|
||||
indices.append(load.getIndices().begin(), load.getIndices().end());
|
||||
indices.append(load.getMapOperands().begin(),
|
||||
load.getMapOperands().end());
|
||||
}
|
||||
auto permutationMap =
|
||||
makePermutationMap(opInst, indices, state->strategy->loopToVectorDim);
|
||||
|
@ -1038,7 +1039,7 @@ static Operation *vectorizeOneOperation(Operation *opInst,
|
|||
auto *value = store.getValueToStore();
|
||||
auto *vectorValue = vectorizeOperand(value, opInst, state);
|
||||
|
||||
SmallVector<Value *, 4> mapOperands(store.getIndices());
|
||||
SmallVector<Value *, 4> mapOperands(store.getMapOperands());
|
||||
SmallVector<Value *, 8> indices;
|
||||
indices.reserve(store.getMemRefType().getRank());
|
||||
if (store.getAffineMap() !=
|
||||
|
@ -1046,7 +1047,8 @@ static Operation *vectorizeOneOperation(Operation *opInst,
|
|||
computeMemoryOpIndices(opInst, store.getAffineMap(), mapOperands,
|
||||
indices);
|
||||
} else {
|
||||
indices.append(store.getIndices().begin(), store.getIndices().end());
|
||||
indices.append(store.getMapOperands().begin(),
|
||||
store.getMapOperands().end());
|
||||
}
|
||||
|
||||
auto permutationMap =
|
||||
|
|
|
@ -424,6 +424,7 @@ func @fold_empty_loop() {
|
|||
}
|
||||
return
|
||||
}
|
||||
// CHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -476,3 +477,29 @@ func @canonicalize_bounds(%M : index, %N : index) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Compose maps into affine load and store ops.
|
||||
|
||||
// CHECK-DAG: #map{{[0-9]+}} = (d0) -> (d0 + 1)
|
||||
|
||||
// CHECK-LABEL: @compose_into_affine_load_store
|
||||
func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
|
||||
%cf1 = constant 1.0 : f32
|
||||
// CHECK: affine.for %[[IV:.*]] = 0 to 1024
|
||||
affine.for %i = 0 to 1024 {
|
||||
// Make sure the unused operand (%u below) gets dropped as well.
|
||||
%idx = affine.apply (d0, d1) -> (d0 + 1) (%i, %u)
|
||||
affine.load %A[%idx] : memref<1024xf32>
|
||||
affine.store %cf1, %A[%idx] : memref<1024xf32>
|
||||
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]] + 1]
|
||||
// CHECK-NEXT: affine.store %cst, %{{.*}}[%[[IV]] + 1]
|
||||
|
||||
// Map remains the same, but operand changes on composition.
|
||||
%copy = affine.apply (d0) -> (d0) (%i)
|
||||
affine.load %A[%copy] : memref<1024xf32>
|
||||
// CHECK-NEXT: affine.load %{{.*}}[%[[IV]]]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -256,13 +256,12 @@ func @xor_self_tensor(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
|||
|
||||
// CHECK-LABEL: func @memref_cast_folding
|
||||
func @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
|
||||
// CHECK-NOT: memref_cast
|
||||
%1 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
|
||||
// CHECK-NEXT: %c0 = constant 0 : index
|
||||
%c0 = constant 0 : index
|
||||
// CHECK-NOT: dim
|
||||
%dim = dim %1, 0 : memref<? x f32>
|
||||
|
||||
// CHECK: affine.load %arg0[%c4 - 1]
|
||||
// CHECK-NEXT: affine.load %arg0[3]
|
||||
affine.load %1[%dim - 1] : memref<?xf32>
|
||||
|
||||
// CHECK-NEXT: store %arg1, %arg0[%c0] : memref<4xf32>
|
||||
|
|
Loading…
Reference in New Issue