Add a canonicalizer for ViewOp which folds constant operands into the ViewOp's memref shape and into the strides and offset of its layout map.
PiperOrigin-RevId: 279088023
commit 5fbdb67b0a (parent a10d836c6d)
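
As a sketch of the rewrite (distilled from the tests added below; %base, %v and %f are illustrative names, not from the commit):

  // Before: shape and offset arrive as constant SSA operands.
  #map = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
  %c7 = constant 7 : index
  %c11 = constant 11 : index
  %c15 = constant 15 : index
  %v = view %base[%c7, %c11][%c15] : memref<2048xi8> to memref<?x?xf32, #map>

  // After: the constants are folded into the result type, and a memref_cast
  // restores the original type for existing users.
  #folded = (d0, d1) -> (d0 * 11 + d1 + 15)
  %f = view %base[][] : memref<2048xi8> to memref<7x11xf32, #folded>
  %v = memref_cast %f : memref<7x11xf32, #folded> to memref<?x?xf32, #map>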
@@ -1192,7 +1192,8 @@ def ViewOp : Std_Op<"view"> {
              operand_begin() + 1 + getType().getNumDynamicDims()};
    }
  }];
  // TODO(andydavis) Add canonicalizer to fold constants into shape and map.

  let hasCanonicalizer = 1;
}

def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
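
Note: this hunk resolves the TODO above (which this commit removes) by setting let hasCanonicalizer = 1, which makes ODS declare a getCanonicalizationPatterns hook on the generated ViewOp class; the C++ hunk below defines that hook and registers the new ViewOpShapeFolder pattern.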
@@ -2419,6 +2419,118 @@ static LogicalResult verify(ViewOp op) {
  return success();
}

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ViewOp viewOp,
                                     PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value *operand) {
          return matchPattern(operand, m_ConstantIndex());
        }))
      return matchFailure();

    // Get result memref type.
    auto memrefType = viewOp.getType();
    if (memrefType.getAffineMaps().size() != 1)
      return matchFailure();
    auto map = memrefType.getAffineMaps()[0];

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value *, 4> newOperands;
    SmallVector<Value *, 4> droppedOperands;

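    // Operand 0 is the base memref; dynamic size operands start at index 1,
    // hence 'dynamicDimPos' begins at 1.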
    unsigned dynamicDimPos = 1;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
        // Record to check for zero uses later below.
        droppedOperands.push_back(constantIndexOp);
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.getOperand(dynamicDimPos));
      }
      dynamicDimPos++;
    }

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    llvm::SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return matchFailure();

    // Fold dynamic offset operand if it is produced by a constant.
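    // The dynamic offset operand is optional; getDynamicOffset() is assumed
    // to return nullptr when it is absent (hence the null check below).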
    auto *dynamicOffset = viewOp.getDynamicOffset();
    int64_t newOffset = oldOffset;
    unsigned dynamicOffsetOperandCount = 0;
    if (dynamicOffset != nullptr) {
      auto *defOp = dynamicOffset->getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic offset will be folded into the map.
        newOffset = constantIndexOp.getValue();
        droppedOperands.push_back(dynamicOffset);
      } else {
        // Unable to fold dynamic offset. Add it to 'newOperands' list.
        newOperands.push_back(dynamicOffset);
        dynamicOffsetOperandCount = 1;
      }
    }

    // Compute new strides based on 'newShapeConstants'.
    SmallVector<int64_t, 4> newStrides(rank);
    newStrides[rank - 1] = 1;
    bool dynamicStrides = false;
    for (int i = rank - 2; i >= 0; --i) {
      if (ShapedType::isDynamic(newShapeConstants[i + 1]))
        dynamicStrides = true;
      if (dynamicStrides)
        newStrides[i] = MemRefType::getDynamicStrideOrOffset();
      else
        newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
    }
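    // Worked example (mirrors the tests below): for a folded shape [7, 11]
    // these strides are [11, 1]; with a folded offset of 15 the map
    // regenerated next is (d0, d1) -> (d0 * 11 + d1 + 15).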

    // Regenerate strided layout map with 'newStrides' and 'newOffset'.
    map = makeStridedLinearLayoutMap(newStrides, newOffset,
                                     rewriter.getContext());

    // Create new memref type with constant folded dims and/or offset/strides.
    auto newMemRefType =
        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
                        memrefType.getMemorySpace());
    assert(static_cast<int64_t>(newOperands.size()) ==
           dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());

    // Create new ViewOp.
    auto newShapeCastOp = rewriter.create<ViewOp>(
        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
                                              newShapeCastOp, viewOp.getType());
    return matchSuccess();
  }
};

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ViewOpShapeFolder>(context);
}

//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//

@@ -1,5 +1,15 @@
// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

#TEST_VIEW_MAP0 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
#TEST_VIEW_MAP1 = (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)
#TEST_VIEW_MAP2 = (d0, d1)[s0, s1] -> (d0 * 4 + d1 + s1)

// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1) -> (d0 * 11 + d1 + 15)
// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 11 + s0 + d1)
// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1 + 15)
// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)
// CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1 + 15)

// CHECK-LABEL: func @test_subi_zero
func @test_subi_zero(%arg0: i32) -> i32 {
  // CHECK-NEXT: %c0_i32 = constant 0 : i32
@@ -579,3 +589,39 @@ func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>,
  // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
  return %4, %5 : tensor<2xi32>, memref<2xi32>
}

// CHECK-LABEL: func @view
func @view(%arg0 : index) {
  %0 = alloc() : memref<2048xi8>
  %c7 = constant 7 : index
  %c11 = constant 11 : index
  %c15 = constant 15 : index

  // Test: fold constant sizes and offset, update map with static stride/offset.
  // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP0]]>
  %1 = view %0[%c7, %c11][%c15]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  // Test: fold constant sizes but not offset, update map with static strides.
  // Test that we do not fold a dynamic dim which is not produced by a constant.
  // CHECK: std.view %0[][%arg0] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP1]]>
  %2 = view %0[%c7, %c11][%arg0]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  // Test: fold constant offset but not sizes, update map with constant offset.
  // Test that we fold the constant offset but not the dynamic dims.
  // CHECK: std.view %0[%arg0, %arg0][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
  %3 = view %0[%arg0, %arg0][%c15]
    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
  // Test: fold one constant dim, no offset; should update with a constant
  // stride on dim 1, but leave a dynamic stride on dim 0.
  // CHECK: std.view %0[%arg0, %arg0][] : memref<2048xi8> to memref<?x?x7xf32, #[[VIEW_MAP3]]>
  %4 = view %0[%arg0, %arg0, %c7][]
    : memref<2048xi8> to memref<?x?x?xf32, #TEST_VIEW_MAP1>

  // Test: preserve an existing static dim size while folding a dynamic
  // dimension and offset.
  // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
  %5 = view %0[%c7][%c15]
    : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>

  return
}
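
Note that these tests only CHECK the canonicalized std.view line; the pattern also inserts a memref_cast from the folded type back to each original result type (see replaceOpWithNewOp<MemRefCastOp> above), which keeps existing uses type-correct.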