Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp."

This reverts commit 3104994104.
This commit is contained in:
Okwan Kwon 2022-02-28 19:11:29 +00:00
parent 278b407a30
commit 4f5eb53e68
4 changed files with 4 additions and 207 deletions

View File

@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@ -21,19 +20,6 @@ namespace tensor {
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
/// Function to control the folding of constant and extract slice
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
/// Patterns to fold the extract slice op with its constant operand
void populateFoldConstantExtractSlicePatterns(
RewritePatternSet &patterns,
const ControlConstantExtractSliceFusionFn &controlFn =
[](ExtractSliceOp op) {
// Disable by default because the folding can generate a large
// constant tensor, which would affect the compile time and storage.
return false;
});
} // namespace tensor
} // namespace mlir

View File

@ -6,14 +6,17 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@ -1155,134 +1158,8 @@ public:
return success();
}
};
/// Slice elements from `values` into `outValues`. `counts` represents the
/// numbers of elements to stride in the original values for each dimension.
/// The output values can be used to construct a DenseElementsAttr.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
llvm::SmallVectorImpl<ElemTy> *outValues) {
assert(offsets.size() == sizes.size());
assert(offsets.size() == strides.size());
if (offsets.empty())
return;
int64_t offset = offsets.front();
int64_t size = sizes.front();
int64_t stride = strides.front();
if (offsets.size() == 1) {
for (int64_t i = 0; i < size; ++i, offset += stride)
outValues->push_back(*(values + offset));
return;
}
for (int64_t i = 0; i < size; ++i, offset += stride) {
auto begin = values + offset * counts.front();
sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
offsets.drop_front(), sizes.drop_front(),
strides.drop_front(), outValues);
}
}
/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
/// operation might introduce more constant data; Users can control their
/// heuristics by the control function.
class ConstantOpExtractSliceFolder final
: public OpRewritePattern<ExtractSliceOp> {
public:
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
ConstantOpExtractSliceFolder(MLIRContext *context,
ControlConstantExtractSliceFusionFn controlFn)
: OpRewritePattern<ExtractSliceOp>(context),
controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(ExtractSliceOp op,
PatternRewriter &rewriter) const override {
DenseElementsAttr attr;
if (!matchPattern(op.source(), m_Constant(&attr)))
return failure();
// A constant splat is handled by fold().
if (attr.isSplat())
return failure();
// Dynamic result shape is not supported.
auto sourceType = op.source().getType().cast<ShapedType>();
auto resultType = op.result().getType().cast<ShapedType>();
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
// Customized control over the folding.
if (!controlFn(op))
return failure();
int64_t count = sourceType.getNumElements();
if (count == 0)
return failure();
// Check if there are any dynamic parts, which are not supported.
auto offsets = extractFromI64ArrayAttr(op.static_offsets());
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
return failure();
auto sizes = extractFromI64ArrayAttr(op.static_sizes());
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
return failure();
auto strides = extractFromI64ArrayAttr(op.static_strides());
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
return failure();
// Compute the stride for each dimension.
SmallVector<int64_t> counts;
ArrayRef<int64_t> shape = sourceType.getShape();
counts.reserve(shape.size());
for (int64_t v : shape) {
count = count / v;
counts.push_back(count);
}
// New attribute constructed by the sliced values.
DenseElementsAttr newAttr;
if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
SmallVector<APInt> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
elems.begin(), counts, offsets, sizes, strides, &outValues);
newAttr = DenseElementsAttr::get(resultType, outValues);
} else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
SmallVector<APFloat> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
elems.begin(), counts, offsets, sizes, strides, &outValues);
newAttr = DenseElementsAttr::get(resultType, outValues);
}
if (newAttr) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
return success();
}
return failure();
}
private:
/// This additionally controls whether the fold happens or not. Users can
/// impose their heuristics in the function.
ControlConstantExtractSliceFusionFn controlFn;
};
} // namespace
void mlir::tensor::populateFoldConstantExtractSlicePatterns(
RewritePatternSet &patterns,
const ControlConstantExtractSliceFusionFn &controlFn) {
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
@ -1361,7 +1238,6 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
return this->source();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
return OpFoldResult();
}

View File

@ -1,39 +0,0 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
// CHECK-LABEL: func @slice_constant
// CHECK-NOT: tensor.extract_slice
// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
// CHECK: return %[[CONST]] : tensor<1x1xf32>
func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
{
%cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
%slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
return %slice : tensor<1x1xf32>
}
// -----
// CHECK-LABEL: func @slice_constant_3x4
// CHECK-NOT: tensor.extract_slice
// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
// CHECK: return %[[CONST]] : tensor<2x2xf32>
func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
{
%cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
%slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
return %slice : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @slice_constant_3x4_offsets
// CHECK-NOT: tensor.extract_slice
// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
// CHECK: return %[[CONST]] : tensor<2x2xf32>
func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
{
%cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
%slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
return %slice : tensor<2x2xf32>
}

View File

@ -41,11 +41,6 @@ struct TestTensorTransforms
*this, "test-split-padding-patterns",
llvm::cl::desc("Test patterns to split tensor.pad ops"),
llvm::cl::init(false)};
Option<bool> testFoldConstantExtractSlice{
*this, "test-fold-constant-extract-slice",
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
llvm::cl::init(false)};
};
} // namespace
@ -55,31 +50,10 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
tensor::ControlConstantExtractSliceFusionFn controlFn =
[](tensor::ExtractSliceOp op) {
if (!op.source().hasOneUse())
return false;
auto resultType = op.result().getType().cast<ShapedType>();
constexpr int64_t kConstantFoldingMaxNumElements = 1024;
if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
return false;
return true;
};
tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
void TestTensorTransforms::runOnOperation() {
FuncOp func = getOperation();
if (testSplitPaddingPatterns)
applySplitPaddingPatterns(func);
if (testFoldConstantExtractSlice)
applyFoldConstantExtractSlicePatterns(func);
}
namespace mlir {