forked from OSchip/llvm-project
[mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.
When the operand to the linalg.tensor_reshape op is a splat constant, the result can be replaced with a splat constant of the same value but different type. Differential Revision: https://reviews.llvm.org/D86117
This commit is contained in:
parent
87122c3480
commit
a65a50540e
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -734,9 +735,28 @@ static LogicalResult verify(TensorReshapeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Reshape of a splat constant can be replaced with a constant of the result
|
||||
/// type.
|
||||
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
|
||||
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr attr;
|
||||
if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
|
||||
return failure();
|
||||
if (!attr || !attr.isSplat())
|
||||
return failure();
|
||||
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
|
||||
reshapeOp.getResultType(), attr.getRawData(), true);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void TensorReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
|
||||
results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -203,3 +203,60 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
|
|||
// CHECK-NOT: linalg.copy
|
||||
// CHECK-NEXT: linalg.generic
|
||||
|
||||
// -----
|
||||
|
||||
func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
|
||||
{
|
||||
%c0 = constant dense<42> : tensor<2x8xi32>
|
||||
%0 = linalg.tensor_reshape %c0
|
||||
[affine_map<(d0, d1, d2) -> (d0)>,
|
||||
affine_map<(d0, d1, d2) -> (d1, d2)>]
|
||||
: tensor<2x8xi32> into tensor<2x4x2xi32>
|
||||
return %0 : tensor<2x4x2xi32>
|
||||
}
|
||||
// CHECK-LABEL: @reshape_splat_constant_int32
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
// CHECK: return %[[CST]]
|
||||
|
||||
func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
|
||||
{
|
||||
%c0 = constant dense<42> : tensor<2x8xi16>
|
||||
%0 = linalg.tensor_reshape %c0
|
||||
[affine_map<(d0, d1, d2) -> (d0)>,
|
||||
affine_map<(d0, d1, d2) -> (d1, d2)>]
|
||||
: tensor<2x8xi16> into tensor<2x4x2xi16>
|
||||
return %0 : tensor<2x4x2xi16>
|
||||
}
|
||||
// CHECK-LABEL: @reshape_splat_constant_int16
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
// CHECK: return %[[CST]]
|
||||
|
||||
func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
|
||||
{
|
||||
%c0 = constant dense<42.0> : tensor<2x8xf32>
|
||||
%0 = linalg.tensor_reshape %c0
|
||||
[affine_map<(d0, d1, d2) -> (d0)>,
|
||||
affine_map<(d0, d1, d2) -> (d1, d2)>]
|
||||
: tensor<2x8xf32> into tensor<2x4x2xf32>
|
||||
return %0 : tensor<2x4x2xf32>
|
||||
}
|
||||
// CHECK-LABEL: @reshape_splat_constant_float32
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
// CHECK: return %[[CST]]
|
||||
|
||||
func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
|
||||
{
|
||||
%c0 = constant dense<42.0> : tensor<2x8xf64>
|
||||
%0 = linalg.tensor_reshape %c0
|
||||
[affine_map<(d0, d1, d2) -> (d0)>,
|
||||
affine_map<(d0, d1, d2) -> (d1, d2)>]
|
||||
: tensor<2x8xf64> into tensor<2x4x2xf64>
|
||||
return %0 : tensor<2x4x2xf64>
|
||||
}
|
||||
// CHECK-LABEL: @reshape_splat_constant_float64
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
// CHECK: return %[[CST]]
|
||||
|
|
Loading…
Reference in New Issue