[mlir][tosa] Add some transpose folders

* If the input is a constant splat value, we just
  need to reshape it.
* If the input is a general constant with one user,
  we can also constant fold it, without bloating
  the IR.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D110439
This commit is contained in:
Lei Zhang 2021-09-24 15:21:11 -04:00
parent 5eb6b82729
commit e325ebb9c7
3 changed files with 154 additions and 8 deletions

View File

@ -1534,6 +1534,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
outs Tosa_Tensor1Dto6D:$output
);
let hasCanonicalizer = 1;
let hasFolder = 1;
}

View File

@ -159,6 +159,71 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<ReshapeReshapeOptimization>(context);
}
struct ConstantTransposeOptimization
: public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
DenseElementsAttr inputValues;
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
return failure();
// Make sure the input is a constant that has a single user.
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
return failure();
DenseIntElementsAttr permAttr;
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::to_vector<6>(llvm::map_range(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getZExtValue(); }));
auto inputType = op.input1().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t numElements = inputType.getNumElements();
auto outputType = op.getType().cast<ShapedType>();
ArrayRef<int64_t> outputShape = outputType.getShape();
SmallVector<Attribute, 4> outputValues;
outputValues.resize(numElements);
// Transpose the input constant. Because we don't know its rank in advance,
// we need to loop over the range [0, element count) and delinearize the
// index.
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
++srcLinearIndex) {
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
int totalCount = srcLinearIndex;
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
srcIndices[dim] = totalCount % inputShape[dim];
totalCount /= inputShape[dim];
}
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
dstIndices[dim] = srcIndices[permValues[dim]];
uint64_t dstLinearIndex = dstIndices.front();
for (int dim = 1; dim < outputType.getRank(); ++dim)
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
}
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
op, outputType, DenseElementsAttr::get(outputType, outputValues));
return success();
}
};
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ConstantTransposeOptimization>(context);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
@ -225,15 +290,18 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
if (!operands[1])
return {};
DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
bool isRange = true;
for (auto it : llvm::enumerate(perms)) {
isRange = isRange &&
it.value().getSExtValue() == static_cast<int64_t>(it.index());
// Transposing splat values just means reshaping.
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
if (input.isSplat())
return input.reshape(getType().cast<ShapedType>());
}
if (isRange && input1().getType() == getType())
auto perms = llvm::to_vector<6>(llvm::map_range(
operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
input1().getType() == getType())
return input1();
return {};
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
@ -237,3 +237,80 @@ func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_splat
func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_2d_float
func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_4d_int
func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
%input = "tosa.const"() {value = dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// -----
// CHECK-LABEL: @transpose_nofold_non_cst_input
func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_multi_users
func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tosa.transpose
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}