forked from OSchip/llvm-project
[mlir][tosa] Add some transpose folders
* If the input is a constant splat value, we just need to reshape it. * If the input is a general constant with one user, we can also constant fold it, without bloating the IR. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D110439
This commit is contained in:
parent
5eb6b82729
commit
e325ebb9c7
|
@ -1534,6 +1534,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
|
|||
outs Tosa_Tensor1Dto6D:$output
|
||||
);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -159,6 +159,71 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
results.insert<ReshapeReshapeOptimization>(context);
|
||||
}
|
||||
|
||||
struct ConstantTransposeOptimization
|
||||
: public OpRewritePattern<tosa::TransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr inputValues;
|
||||
if (!matchPattern(op.input1(), m_Constant(&inputValues)))
|
||||
return failure();
|
||||
// Make sure the input is a constant that has a single user.
|
||||
if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
|
||||
return failure();
|
||||
|
||||
DenseIntElementsAttr permAttr;
|
||||
if (!matchPattern(op.perms(), m_Constant(&permAttr)))
|
||||
return failure();
|
||||
auto permValues = llvm::to_vector<6>(llvm::map_range(
|
||||
// TOSA allows both 32- and 64-bit integer tensors here.
|
||||
permAttr.getValues<APInt>(),
|
||||
[](const APInt &val) { return val.getZExtValue(); }));
|
||||
|
||||
auto inputType = op.input1().getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t numElements = inputType.getNumElements();
|
||||
|
||||
auto outputType = op.getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
SmallVector<Attribute, 4> outputValues;
|
||||
outputValues.resize(numElements);
|
||||
|
||||
// Transpose the input constant. Because we don't know its rank in advance,
|
||||
// we need to loop over the range [0, element count) and delinearize the
|
||||
// index.
|
||||
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
|
||||
++srcLinearIndex) {
|
||||
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
|
||||
int totalCount = srcLinearIndex;
|
||||
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
|
||||
srcIndices[dim] = totalCount % inputShape[dim];
|
||||
totalCount /= inputShape[dim];
|
||||
}
|
||||
|
||||
SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
|
||||
for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
|
||||
dstIndices[dim] = srcIndices[permValues[dim]];
|
||||
|
||||
uint64_t dstLinearIndex = dstIndices.front();
|
||||
for (int dim = 1; dim < outputType.getRank(); ++dim)
|
||||
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
|
||||
|
||||
outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
||||
op, outputType, DenseElementsAttr::get(outputType, outputValues));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ConstantTransposeOptimization>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operator Folders.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -225,15 +290,18 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (!operands[1])
|
||||
return {};
|
||||
|
||||
DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
|
||||
|
||||
bool isRange = true;
|
||||
for (auto it : llvm::enumerate(perms)) {
|
||||
isRange = isRange &&
|
||||
it.value().getSExtValue() == static_cast<int64_t>(it.index());
|
||||
// Transposing splat values just means reshaping.
|
||||
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (input.isSplat())
|
||||
return input.reshape(getType().cast<ShapedType>());
|
||||
}
|
||||
|
||||
if (isRange && input1().getType() == getType())
|
||||
auto perms = llvm::to_vector<6>(llvm::map_range(
|
||||
operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
|
||||
[](const APInt &val) { return val.getSExtValue(); }));
|
||||
|
||||
if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
|
||||
input1().getType() == getType())
|
||||
return input1();
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt --canonicalize %s | FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @argmax_nofold
|
||||
func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
|
||||
|
@ -237,3 +237,80 @@ func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
|
|||
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_splat
|
||||
func @transpose_fold_splat() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_2d_float
|
||||
func @transpose_fold_2d_float() -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_fold_4d_int
|
||||
func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
|
||||
%input = "tosa.const"() {value = dense<[[
|
||||
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
|
||||
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
|
||||
]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
|
||||
%perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
// CHECK: %[[CST:.+]] = "tosa.const"()
|
||||
// CHECK-SAME{LITERAL}: value = dense<[
|
||||
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
|
||||
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
|
||||
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
|
||||
// CHECK-SAME{LITERAL}: ]>
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
return %1 : tensor<3x1x4x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_input
|
||||
func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_non_cst_perms
|
||||
func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @transpose_nofold_multi_users
|
||||
func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
|
||||
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
|
||||
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tosa.transpose
|
||||
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue