forked from OSchip/llvm-project
Revert "[mlir] Fold Arithmetic::ConstantOp and Tensor::ExtractSliceOp."
This reverts commit 3104994104
.
This commit is contained in:
parent
278b407a30
commit
4f5eb53e68
|
@ -9,7 +9,6 @@
|
|||
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
|
||||
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -21,19 +20,6 @@ namespace tensor {
|
|||
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit baseBenefit = 1);
|
||||
|
||||
/// Function to control the folding of constant and extract slice
|
||||
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
|
||||
|
||||
/// Patterns to fold the extract slice op with its constant operand
|
||||
void populateFoldConstantExtractSlicePatterns(
|
||||
RewritePatternSet &patterns,
|
||||
const ControlConstantExtractSliceFusionFn &controlFn =
|
||||
[](ExtractSliceOp op) {
|
||||
// Disable by default because the folding can generate a large
|
||||
// constant tensor, which would affect the compile time and storage.
|
||||
return false;
|
||||
});
|
||||
|
||||
} // namespace tensor
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -6,14 +6,17 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
@ -1155,134 +1158,8 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Slice elements from `values` into `outValues`. `counts` represents the
|
||||
/// numbers of elements to stride in the original values for each dimension.
|
||||
/// The output values can be used to construct a DenseElementsAttr.
|
||||
template <typename IterTy, typename ElemTy>
|
||||
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
|
||||
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides,
|
||||
llvm::SmallVectorImpl<ElemTy> *outValues) {
|
||||
assert(offsets.size() == sizes.size());
|
||||
assert(offsets.size() == strides.size());
|
||||
if (offsets.empty())
|
||||
return;
|
||||
|
||||
int64_t offset = offsets.front();
|
||||
int64_t size = sizes.front();
|
||||
int64_t stride = strides.front();
|
||||
if (offsets.size() == 1) {
|
||||
for (int64_t i = 0; i < size; ++i, offset += stride)
|
||||
outValues->push_back(*(values + offset));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < size; ++i, offset += stride) {
|
||||
auto begin = values + offset * counts.front();
|
||||
sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
|
||||
offsets.drop_front(), sizes.drop_front(),
|
||||
strides.drop_front(), outValues);
|
||||
}
|
||||
}
|
||||
|
||||
/// Fold arith.constant and tensor.extract_slice into arith.constant. The folded
|
||||
/// operation might introduce more constant data; Users can control their
|
||||
/// heuristics by the control function.
|
||||
class ConstantOpExtractSliceFolder final
|
||||
: public OpRewritePattern<ExtractSliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
|
||||
|
||||
ConstantOpExtractSliceFolder(MLIRContext *context,
|
||||
ControlConstantExtractSliceFusionFn controlFn)
|
||||
: OpRewritePattern<ExtractSliceOp>(context),
|
||||
controlFn(std::move(controlFn)) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr attr;
|
||||
if (!matchPattern(op.source(), m_Constant(&attr)))
|
||||
return failure();
|
||||
|
||||
// A constant splat is handled by fold().
|
||||
if (attr.isSplat())
|
||||
return failure();
|
||||
|
||||
// Dynamic result shape is not supported.
|
||||
auto sourceType = op.source().getType().cast<ShapedType>();
|
||||
auto resultType = op.result().getType().cast<ShapedType>();
|
||||
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Customized control over the folding.
|
||||
if (!controlFn(op))
|
||||
return failure();
|
||||
|
||||
int64_t count = sourceType.getNumElements();
|
||||
if (count == 0)
|
||||
return failure();
|
||||
|
||||
// Check if there are any dynamic parts, which are not supported.
|
||||
auto offsets = extractFromI64ArrayAttr(op.static_offsets());
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
|
||||
return failure();
|
||||
auto sizes = extractFromI64ArrayAttr(op.static_sizes());
|
||||
if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
|
||||
return failure();
|
||||
auto strides = extractFromI64ArrayAttr(op.static_strides());
|
||||
if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
|
||||
return failure();
|
||||
|
||||
// Compute the stride for each dimension.
|
||||
SmallVector<int64_t> counts;
|
||||
ArrayRef<int64_t> shape = sourceType.getShape();
|
||||
counts.reserve(shape.size());
|
||||
for (int64_t v : shape) {
|
||||
count = count / v;
|
||||
counts.push_back(count);
|
||||
}
|
||||
|
||||
// New attribute constructed by the sliced values.
|
||||
DenseElementsAttr newAttr;
|
||||
|
||||
if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
SmallVector<APInt> outValues;
|
||||
outValues.reserve(sourceType.getNumElements());
|
||||
sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
|
||||
elems.begin(), counts, offsets, sizes, strides, &outValues);
|
||||
newAttr = DenseElementsAttr::get(resultType, outValues);
|
||||
} else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
SmallVector<APFloat> outValues;
|
||||
outValues.reserve(sourceType.getNumElements());
|
||||
sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
|
||||
elems.begin(), counts, offsets, sizes, strides, &outValues);
|
||||
newAttr = DenseElementsAttr::get(resultType, outValues);
|
||||
}
|
||||
|
||||
if (newAttr) {
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
/// This additionally controls whether the fold happens or not. Users can
|
||||
/// impose their heuristics in the function.
|
||||
ControlConstantExtractSliceFusionFn controlFn;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tensor::populateFoldConstantExtractSlicePatterns(
|
||||
RewritePatternSet &patterns,
|
||||
const ControlConstantExtractSliceFusionFn &controlFn) {
|
||||
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
|
||||
}
|
||||
|
||||
/// Return the canonical type of the result of an extract_slice op.
|
||||
struct SliceReturnTypeCanonicalizer {
|
||||
RankedTensorType operator()(ExtractSliceOp op,
|
||||
|
@ -1361,7 +1238,6 @@ OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
|
|||
return this->source();
|
||||
if (Value slice = foldExtractAfterInsertSlice(*this))
|
||||
return slice;
|
||||
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @slice_constant
|
||||
// CHECK-NOT: tensor.extract_slice
|
||||
// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
|
||||
// CHECK: return %[[CONST]] : tensor<1x1xf32>
|
||||
func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
|
||||
{
|
||||
%cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
|
||||
%slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
|
||||
return %slice : tensor<1x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @slice_constant_3x4
|
||||
// CHECK-NOT: tensor.extract_slice
|
||||
// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CONST]] : tensor<2x2xf32>
|
||||
func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
|
||||
{
|
||||
%cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
|
||||
%slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
|
||||
return %slice : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @slice_constant_3x4_offsets
|
||||
// CHECK-NOT: tensor.extract_slice
|
||||
// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CONST]] : tensor<2x2xf32>
|
||||
func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
|
||||
{
|
||||
%cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
|
||||
%slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
|
||||
return %slice : tensor<2x2xf32>
|
||||
}
|
||||
|
|
@ -41,11 +41,6 @@ struct TestTensorTransforms
|
|||
*this, "test-split-padding-patterns",
|
||||
llvm::cl::desc("Test patterns to split tensor.pad ops"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
Option<bool> testFoldConstantExtractSlice{
|
||||
*this, "test-fold-constant-extract-slice",
|
||||
llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -55,31 +50,10 @@ static void applySplitPaddingPatterns(FuncOp funcOp) {
|
|||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
tensor::ControlConstantExtractSliceFusionFn controlFn =
|
||||
[](tensor::ExtractSliceOp op) {
|
||||
if (!op.source().hasOneUse())
|
||||
return false;
|
||||
|
||||
auto resultType = op.result().getType().cast<ShapedType>();
|
||||
constexpr int64_t kConstantFoldingMaxNumElements = 1024;
|
||||
if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
void TestTensorTransforms::runOnOperation() {
|
||||
FuncOp func = getOperation();
|
||||
if (testSplitPaddingPatterns)
|
||||
applySplitPaddingPatterns(func);
|
||||
if (testFoldConstantExtractSlice)
|
||||
applyFoldConstantExtractSlicePatterns(func);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
|
Loading…
Reference in New Issue