From 5fbdb67b0aa7f01b17dcca62e08e3db38d021fce Mon Sep 17 00:00:00 2001
From: Andy Davis
Date: Thu, 7 Nov 2019 08:04:33 -0800
Subject: [PATCH] Add canonicalizer for ViewOp which folds constants into the
 ViewOp memref shape and layout map strides and offset.

PiperOrigin-RevId: 279088023
---
 mlir/include/mlir/Dialect/StandardOps/Ops.td |   3 +-
 mlir/lib/Dialect/StandardOps/Ops.cpp         | 112 +++++++++++++++++++
 mlir/test/Transforms/canonicalize.mlir       |  46 ++++++++
 3 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index be20c3823268..4dd22bab2d99 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1192,7 +1192,8 @@ def ViewOp : Std_Op<"view"> {
               operand_begin() + 1 + getType().getNumDynamicDims()};
     }
   }];
-  // TODO(andydavis) Add canonicalizer to fold constants into shape and map.
+
+  let hasCanonicalizer = 1;
 }
 
 def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 82d4324dff84..60002649a216 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2419,6 +2419,118 @@ static LogicalResult verify(ViewOp op) {
   return success();
 }
 
+namespace {
+
+struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
+  using OpRewritePattern<ViewOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ViewOp viewOp,
+                                     PatternRewriter &rewriter) const override {
+    // Return if none of the operands are constants.
+    if (llvm::none_of(viewOp.getOperands(), [](Value *operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return matchFailure();
+
+    // Get result memref type.
+    auto memrefType = viewOp.getType();
+    if (memrefType.getAffineMaps().size() != 1)
+      return matchFailure();
+    auto map = memrefType.getAffineMaps()[0];
+
+    // Fold any dynamic dim operands which are produced by a constant.
+    SmallVector<int64_t, 4> newShapeConstants;
+    newShapeConstants.reserve(memrefType.getRank());
+    SmallVector<Value *, 4> newOperands;
+    SmallVector<Value *, 4> droppedOperands;
+
+    unsigned dynamicDimPos = 1;
+    unsigned rank = memrefType.getRank();
+    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
+      int64_t dimSize = memrefType.getDimSize(dim);
+      // If this is already a static dimension, keep it.
+      if (!ShapedType::isDynamic(dimSize)) {
+        newShapeConstants.push_back(dimSize);
+        continue;
+      }
+      auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
+      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+        // Dynamic shape dimension will be folded.
+        newShapeConstants.push_back(constantIndexOp.getValue());
+        // Record to check for zero uses later below.
+        droppedOperands.push_back(constantIndexOp);
+      } else {
+        // Dynamic shape dimension not folded; copy operand from old memref.
+        newShapeConstants.push_back(dimSize);
+        newOperands.push_back(viewOp.getOperand(dynamicDimPos));
+      }
+      dynamicDimPos++;
+    }
+
+    // Get offset from old memref view type 'memRefType'.
+    int64_t oldOffset;
+    llvm::SmallVector<int64_t, 4> oldStrides;
+    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
+      return matchFailure();
+
+    // Fold dynamic offset operand if it is produced by a constant.
+    auto *dynamicOffset = viewOp.getDynamicOffset();
+    int64_t newOffset = oldOffset;
+    unsigned dynamicOffsetOperandCount = 0;
+    if (dynamicOffset != nullptr) {
+      auto *defOp = dynamicOffset->getDefiningOp();
+      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
+        // Dynamic offset will be folded into the map.
+        newOffset = constantIndexOp.getValue();
+        droppedOperands.push_back(dynamicOffset);
+      } else {
+        // Unable to fold dynamic offset. Add it to 'newOperands' list.
+        newOperands.push_back(dynamicOffset);
+        dynamicOffsetOperandCount = 1;
+      }
+    }
+
+    // Compute new strides based on 'newShapeConstants'.
+    SmallVector<int64_t, 4> newStrides(rank);
+    newStrides[rank - 1] = 1;
+    bool dynamicStrides = false;
+    for (int i = rank - 2; i >= 0; --i) {
+      if (ShapedType::isDynamic(newShapeConstants[i + 1]))
+        dynamicStrides = true;
+      if (dynamicStrides)
+        newStrides[i] = MemRefType::getDynamicStrideOrOffset();
+      else
+        newStrides[i] = newShapeConstants[i + 1] * newStrides[i + 1];
+    }
+
+    // Regenerate strided layout map with 'newStrides' and 'newOffset'.
+    map = makeStridedLinearLayoutMap(newStrides, newOffset,
+                                     rewriter.getContext());
+
+    // Create new memref type with constant folded dims and/or offset/strides.
+    auto newMemRefType =
+        MemRefType::get(newShapeConstants, memrefType.getElementType(), {map},
+                        memrefType.getMemorySpace());
+    assert(static_cast<unsigned>(newOperands.size()) ==
+           dynamicOffsetOperandCount + newMemRefType.getNumDynamicDims());
+
+    // Create new ViewOp.
+    auto newShapeCastOp = rewriter.create<ViewOp>(
+        viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands);
+    // Insert a cast so we have the same type as the old memref type.
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(droppedOperands, viewOp,
+                                              newShapeCastOp, viewOp.getType());
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<ViewOpShapeFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index b100f7213f19..8ccf24061b9b 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1,5 +1,15 @@
 // RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
+#TEST_VIEW_MAP0 = (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)
+#TEST_VIEW_MAP1 = (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)
+#TEST_VIEW_MAP2 = (d0, d1)[s0, s1] -> (d0 * 4 + d1 + s1)
+
+// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1) -> (d0 * 11 + d1 + 15)
+// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 11 + s0 + d1)
+// CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1 + 15)
+// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1, d2)[s0] -> (d0 * s0 + d1 * 7 + d2)
+// CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1 + 15)
+
 // CHECK-LABEL: func @test_subi_zero
 func @test_subi_zero(%arg0: i32) -> i32 {
   // CHECK-NEXT: %c0_i32 = constant 0 : i32
@@ -579,3 +589,39 @@ func @cast_values(%arg0: tensor<*xi32>, %arg1: memref<?xi32>) -> (tensor<2xi32>,
   // CHECK-NEXT: return %0, %1 : tensor<2xi32>, memref<2xi32>
   return %4, %5 : tensor<2xi32>, memref<2xi32>
 }
+
+// CHECK-LABEL: func @view
+func @view(%arg0 : index) {
+  %0 = alloc() : memref<2048xi8>
+  %c7 = constant 7 : index
+  %c11 = constant 11 : index
+  %c15 = constant 15 : index
+
+  // Test: fold constant sizes and offset, update map with static stride/offset.
+  // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP0]]>
+  %1 = view %0[%c7, %c11][%c15]
+    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
+  // Test: fold constant sizes but not offset, update map with static stride.
+  // Test that we do not fold a dynamic dim which is not produced by a constant.
+  // CHECK: std.view %0[][%arg0] : memref<2048xi8> to memref<7x11xf32, #[[VIEW_MAP1]]>
+  %2 = view %0[%c7, %c11][%arg0]
+    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
+  // Test: fold constant offset but not sizes, update map with constant offset.
+  // Test that we fold the constant offset but not the dynamic dims.
+  // CHECK: std.view %0[%arg0, %arg0][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP2]]>
+  %3 = view %0[%arg0, %arg0][%c15]
+    : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>
+  // Test: fold one constant dim, no offset, should update with constant
+  // stride on dim 1, but leave dynamic stride on dim 0.
+  // CHECK: std.view %0[%arg0, %arg0][] : memref<2048xi8> to memref<?x?x7xf32, #[[VIEW_MAP3]]>
+  %4 = view %0[%arg0, %arg0, %c7][]
+    : memref<2048xi8> to memref<?x?x?xf32, #TEST_VIEW_MAP1>
+
+  // Test: preserve an existing static dim size while folding a dynamic
+  // dimension and offset.
+  // CHECK: std.view %0[][] : memref<2048xi8> to memref<7x4xf32, #[[VIEW_MAP4]]>
+  %5 = view %0[%c7][%c15]
+    : memref<2048xi8> to memref<?x4xf32, #TEST_VIEW_MAP2>
+
+  return
+}
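
For context, the net effect of ViewOpShapeFolder on the first test case above is roughly the rewrite sketched below. This is illustrative IR only, not part of the patch: it reuses the SSA names and #TEST_VIEW_MAP0 from the @view test, and #map0 is a placeholder name for the folded layout map the canonicalizer emits.

// Before: the view sizes and offset are all produced by index constants.
%1 = view %0[%c7, %c11][%c15]
  : memref<2048xi8> to memref<?x?xf32, #TEST_VIEW_MAP0>

// After: the constants are folded into a fully static result type, and a
// memref_cast recovers the original dynamic type for existing users.
// #map0 = (d0, d1) -> (d0 * 11 + d1 + 15)
%new = view %0[][] : memref<2048xi8> to memref<7x11xf32, #map0>
%1 = memref_cast %new : memref<7x11xf32, #map0> to memref<?x?xf32, #TEST_VIEW_MAP0>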