[mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.

When the operand to the linalg.tensor_reshape op is a splat constant,
the result can be replaced with a splat constant of the same value but
different type.

Differential Revision: https://reviews.llvm.org/D86117
This commit is contained in:
MaheshRavishankar 2020-08-18 08:16:25 -07:00
parent 87122c3480
commit a65a50540e
2 changed files with 78 additions and 1 deletions

View File

@ -18,6 +18,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@ -734,9 +735,28 @@ static LogicalResult verify(TensorReshapeOp op) {
return success();
}
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
DenseElementsAttr attr;
if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
return failure();
if (!attr || !attr.isSplat())
return failure();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
reshapeOp.getResultType(), attr.getRawData(), true);
rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
return success();
}
};
void TensorReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
context);
}
//===----------------------------------------------------------------------===//

View File

@ -203,3 +203,60 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
// CHECK-NOT: linalg.copy
// CHECK-NEXT: linalg.generic
// -----
func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
{
%c0 = constant dense<42> : tensor<2x8xi32>
%0 = linalg.tensor_reshape %c0
[affine_map<(d0, d1, d2) -> (d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>]
: tensor<2x8xi32> into tensor<2x4x2xi32>
return %0 : tensor<2x4x2xi32>
}
// CHECK-LABEL: @reshape_splat_constant_int32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
// CHECK-NOT: linalg.tensor_reshape
// CHECK: return %[[CST]]
func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
{
%c0 = constant dense<42> : tensor<2x8xi16>
%0 = linalg.tensor_reshape %c0
[affine_map<(d0, d1, d2) -> (d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>]
: tensor<2x8xi16> into tensor<2x4x2xi16>
return %0 : tensor<2x4x2xi16>
}
// CHECK-LABEL: @reshape_splat_constant_int16
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
// CHECK-NOT: linalg.tensor_reshape
// CHECK: return %[[CST]]
func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
{
%c0 = constant dense<42.0> : tensor<2x8xf32>
%0 = linalg.tensor_reshape %c0
[affine_map<(d0, d1, d2) -> (d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>]
: tensor<2x8xf32> into tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: @reshape_splat_constant_float32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
// CHECK-NOT: linalg.tensor_reshape
// CHECK: return %[[CST]]
func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
{
%c0 = constant dense<42.0> : tensor<2x8xf64>
%0 = linalg.tensor_reshape %c0
[affine_map<(d0, d1, d2) -> (d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>]
: tensor<2x8xf64> into tensor<2x4x2xf64>
return %0 : tensor<2x4x2xf64>
}
// CHECK-LABEL: @reshape_splat_constant_float64
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: linalg.tensor_reshape
// CHECK: return %[[CST]]