From 4f5eb53e68b1da47a211a97bd2fe4ea26b590e58 Mon Sep 17 00:00:00 2001 From: Okwan Kwon Date: Mon, 28 Feb 2022 19:11:29 +0000 Subject: [PATCH] Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp." This reverts commit 3104994104f0c2f274acf5e01eb6cc82e9cca06b. --- .../Dialect/Tensor/Transforms/Transforms.h | 14 -- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 132 +----------------- .../Tensor/fold-constant-extract-slice.mlir | 39 ------ .../Dialect/Tensor/TestTensorTransforms.cpp | 26 ---- 4 files changed, 4 insertions(+), 207 deletions(-) delete mode 100644 mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index a4f78750368d..e6267e9cf02e 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -9,7 +9,6 @@ #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -21,19 +20,6 @@ namespace tensor { void populateSplitPaddingPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); -/// Function to control the folding of constant and extract slice -using ControlConstantExtractSliceFusionFn = std::function; - -/// Patterns to fold the extract slice op with its constant operand -void populateFoldConstantExtractSlicePatterns( - RewritePatternSet &patterns, - const ControlConstantExtractSliceFusionFn &controlFn = - [](ExtractSliceOp op) { - // Disable by default because the folding can generate a large - // constant tensor, which would affect the compile time and storage. - return false; - }); - } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index d6a4bb460a7f..70aa7b5fe57f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -6,14 +6,17 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Utils/Utils.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -1155,134 +1158,8 @@ public: return success(); } }; - -/// Slice elements from `values` into `outValues`. `counts` represents the -/// numbers of elements to stride in the original values for each dimension. -/// The output values can be used to construct a DenseElementsAttr. -template -static void sliceElements(IterTy values, ArrayRef counts, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - llvm::SmallVectorImpl *outValues) { - assert(offsets.size() == sizes.size()); - assert(offsets.size() == strides.size()); - if (offsets.empty()) - return; - - int64_t offset = offsets.front(); - int64_t size = sizes.front(); - int64_t stride = strides.front(); - if (offsets.size() == 1) { - for (int64_t i = 0; i < size; ++i, offset += stride) - outValues->push_back(*(values + offset)); - - return; - } - - for (int64_t i = 0; i < size; ++i, offset += stride) { - auto begin = values + offset * counts.front(); - sliceElements(begin, counts.drop_front(), - offsets.drop_front(), sizes.drop_front(), - strides.drop_front(), outValues); - } -} - -/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded -/// operation might introduce more constant data; Users can control their -/// heuristics by the control function. -class ConstantOpExtractSliceFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ConstantOpExtractSliceFolder(MLIRContext *context, - ControlConstantExtractSliceFusionFn controlFn) - : OpRewritePattern(context), - controlFn(std::move(controlFn)) {} - - LogicalResult matchAndRewrite(ExtractSliceOp op, - PatternRewriter &rewriter) const override { - DenseElementsAttr attr; - if (!matchPattern(op.source(), m_Constant(&attr))) - return failure(); - - // A constant splat is handled by fold(). - if (attr.isSplat()) - return failure(); - - // Dynamic result shape is not supported. - auto sourceType = op.source().getType().cast(); - auto resultType = op.result().getType().cast(); - if (!sourceType.hasStaticShape() || !resultType.hasStaticShape()) - return failure(); - - // Customized control over the folding. - if (!controlFn(op)) - return failure(); - - int64_t count = sourceType.getNumElements(); - if (count == 0) - return failure(); - - // Check if there are any dynamic parts, which are not supported. - auto offsets = extractFromI64ArrayAttr(op.static_offsets()); - if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset)) - return failure(); - auto sizes = extractFromI64ArrayAttr(op.static_sizes()); - if (llvm::is_contained(sizes, ShapedType::kDynamicSize)) - return failure(); - auto strides = extractFromI64ArrayAttr(op.static_strides()); - if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset)) - return failure(); - - // Compute the stride for each dimension. - SmallVector counts; - ArrayRef shape = sourceType.getShape(); - counts.reserve(shape.size()); - for (int64_t v : shape) { - count = count / v; - counts.push_back(count); - } - - // New attribute constructed by the sliced values. - DenseElementsAttr newAttr; - - if (auto elems = attr.dyn_cast()) { - SmallVector outValues; - outValues.reserve(sourceType.getNumElements()); - sliceElements( - elems.begin(), counts, offsets, sizes, strides, &outValues); - newAttr = DenseElementsAttr::get(resultType, outValues); - } else if (auto elems = attr.dyn_cast()) { - SmallVector outValues; - outValues.reserve(sourceType.getNumElements()); - sliceElements( - elems.begin(), counts, offsets, sizes, strides, &outValues); - newAttr = DenseElementsAttr::get(resultType, outValues); - } - - if (newAttr) { - rewriter.replaceOpWithNewOp(op, resultType, newAttr); - return success(); - } - - return failure(); - } - -private: - /// This additionally controls whether the fold happens or not. Users can - /// impose their heuristics in the function. - ControlConstantExtractSliceFusionFn controlFn; -}; - } // namespace -void mlir::tensor::populateFoldConstantExtractSlicePatterns( - RewritePatternSet &patterns, - const ControlConstantExtractSliceFusionFn &controlFn) { - patterns.add(patterns.getContext(), controlFn); -} - /// Return the canonical type of the result of an extract_slice op. struct SliceReturnTypeCanonicalizer { RankedTensorType operator()(ExtractSliceOp op, @@ -1361,7 +1238,6 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef operands) { return this->source(); if (Value slice = foldExtractAfterInsertSlice(*this)) return slice; - return OpFoldResult(); } diff --git a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir deleted file mode 100644 index 03c6195d4037..000000000000 --- a/mlir/test/Dialect/Tensor/fold-constant-extract-slice.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s - -// CHECK-LABEL: func @slice_constant -// CHECK-NOT: tensor.extract_slice -// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32> -// CHECK: return %[[CONST]] : tensor<1x1xf32> -func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32> -{ - %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32> - %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32> - return %slice : tensor<1x1xf32> -} - -// ----- - -// CHECK-LABEL: func @slice_constant_3x4 -// CHECK-NOT: tensor.extract_slice -// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32> -// CHECK: return %[[CONST]] : tensor<2x2xf32> -func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32> -{ - %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32> - %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32> - return %slice : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @slice_constant_3x4_offsets -// CHECK-NOT: tensor.extract_slice -// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32> -// CHECK: return %[[CONST]] : tensor<2x2xf32> -func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32> -{ - %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32> - %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32> - return %slice : tensor<2x2xf32> -} - diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 4d947ef3ee53..c720ca1e3a23 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -41,11 +41,6 @@ struct TestTensorTransforms *this, "test-split-padding-patterns", llvm::cl::desc("Test patterns to split tensor.pad ops"), llvm::cl::init(false)}; - - Option testFoldConstantExtractSlice{ - *this, "test-fold-constant-extract-slice", - llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), - llvm::cl::init(false)}; }; } // namespace @@ -55,31 +50,10 @@ static void applySplitPaddingPatterns(FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) { - RewritePatternSet patterns(funcOp.getContext()); - tensor::ControlConstantExtractSliceFusionFn controlFn = - [](tensor::ExtractSliceOp op) { - if (!op.source().hasOneUse()) - return false; - - auto resultType = op.result().getType().cast(); - constexpr int64_t kConstantFoldingMaxNumElements = 1024; - if (resultType.getNumElements() > kConstantFoldingMaxNumElements) - return false; - - return true; - }; - - tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); -} - void TestTensorTransforms::runOnOperation() { FuncOp func = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(func); - if (testFoldConstantExtractSlice) - applyFoldConstantExtractSlicePatterns(func); } namespace mlir {