[DenseElementsAttr] Teach isValidRawBuffer that 1-elt values are splats.

We want getRaw() on tensors with i1 element type with a zero or 1 value
to be treated as a splat.  This fixes:
https://github.com/llvm/llvm-project/issues/55440
This commit is contained in:
Chris Lattner 2022-05-14 11:48:17 +01:00
parent ac7a9ef0ae
commit 5ac9d66209
2 changed files with 26 additions and 126 deletions

View File

@ -809,12 +809,15 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
bool &detectedSplat) {
size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
int64_t numElements = type.getNumElements();
// The initializer is always a splat if the result type has a single element.
detectedSplat = numElements == 1;
// Storage width of 1 is special as it is packed by the bit.
if (storageWidth == 1) {
// Check for a splat, or a buffer equal to the number of elements which
// consists of either all 0's or all 1's.
detectedSplat = false;
if (rawBuffer.size() == 1) {
auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
if (rawByte == 0 || rawByte == 0xff) {
@ -822,12 +825,20 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
return true;
}
}
return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
// This is a valid non-splat buffer if it has the right size.
return rawBufferWidth == llvm::alignTo<8>(numElements);
}
// All other types are 8-bit aligned.
if ((detectedSplat = rawBufferWidth == storageWidth))
// All other types are 8-bit aligned, so we can just check the buffer width
// to know if only a single initializer element was passed in.
if (rawBufferWidth == storageWidth) {
detectedSplat = true;
return true;
return rawBufferWidth == (storageWidth * type.getNumElements());
}
// The raw buffer is valid if it has the right size.
return rawBufferWidth == storageWidth * numElements;
}
/// Check the information for a C++ data type, check if this type is valid for

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
// RUN: mlir-opt --canonicalize %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
@ -7,8 +7,6 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @add_zero_different_shape
func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
// CHECK: tosa.add
@ -18,8 +16,6 @@ func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32>
}
// -----
// CHECK-LABEL: @add_zero_int
func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
@ -29,8 +25,6 @@ func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
return %1 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @cast_fold
func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -38,8 +32,6 @@ func.func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @cast_nofold
func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
// CHECK: "tosa.cast"
@ -47,8 +39,6 @@ func.func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
return %0 : tensor<?x1xi32>
}
// -----
// CHECK-LABEL: @clamp_not_noop
func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: "tosa.clamp"
@ -56,8 +46,6 @@ func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
// -----
// CHECK-LABEL: @clamp_float_is_noop
func.func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: return %arg0
@ -66,8 +54,6 @@ func.func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @clamp_int8_is_noop
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: return %arg0
@ -76,8 +62,6 @@ func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
return %0 : tensor<4xi8>
}
// -----
// CHECK-LABEL: @clamp_int16_is_noop
func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
// CHECK: return %arg0
@ -86,8 +70,6 @@ func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
return %0 : tensor<4xi16>
}
// -----
// CHECK-LABEL: @clamp_uint8_is_noop
func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
// CHECK: return %arg0
@ -96,8 +78,6 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
return %0 : tensor<4xui8>
}
// -----
// CHECK-LABEL: @clamp_twice_is_single_clamp
func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
// CHECK: "tosa.clamp"(%arg0) {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
@ -106,8 +86,6 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
return %1 : tensor<4xi8>
}
// -----
// CHECK-LABEL: @concat_fold
func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -115,8 +93,6 @@ func.func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @concat_fold_cast
func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
// CHECK: %[[VAR0:.*]] = tensor.cast %arg0
@ -125,8 +101,6 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
return %0 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @conv2d_stride_2
func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> {
// CHECK: "tosa.conv2d"
@ -136,8 +110,6 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32
return %0 : tensor<4x10x10x3xf32>
}
// -----
// CHECK-LABEL: @conv2d_weight_2x2
func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> {
// CHECK: "tosa.conv2d"
@ -147,8 +119,6 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf
return %0 : tensor<4x10x10x1xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_stride_2
func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d"
@ -156,8 +126,6 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor
return %0 : tensor<4x10x10x6xf32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_weight_2x2
func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d"
@ -165,8 +133,6 @@ func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tens
return %0 : tensor<4x10x10x6xf32>
}
// -----
// CHECK-LABEL: @max_pool2d_is_noop
func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
// CHECK-NOT: "tosa.max_pool2d"
@ -175,8 +141,6 @@ func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf3
return %0 : tensor<10x1x1x3xf32>
}
// -----
// CHECK-LABEL: @pad_noop
func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: return %arg0
@ -185,8 +149,6 @@ func.func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_i32
func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
@ -196,8 +158,6 @@ func.func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>
return %1 : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: @pad_determine_val_f32
func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
@ -207,8 +167,6 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @pad_determine_val_quant
func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
// CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor<i32>}
@ -218,8 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
return %1 : tensor<?x?xi32>
}
// -----
// CHECK-LABEL: @mul_one_different_shape
func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
// CHECK: tosa.mul
@ -228,8 +184,6 @@ func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32>
return %1 : tensor<4x2x3xf32>
}
// -----
// CHECK-LABEL: @mul_one_float
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: return %arg0
@ -239,8 +193,6 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
return %1 : tensor<2x3xf32>
}
// -----
// CHECK-LABEL: @mul_one_int
func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK: return %arg0
@ -250,8 +202,6 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
return %1 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @select_same_value
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -260,8 +210,6 @@ func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> t
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @select_true_value
func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%c1 = "tosa.const"() {value = dense<1> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@ -271,8 +219,6 @@ func.func @select_true_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @select_false_value
func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%c0 = "tosa.const"() {value = dense<0> : tensor<2x3xi1>} : () -> tensor<2x3xi1>
@ -282,8 +228,6 @@ func.func @select_false_value(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>) ->
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @select_not_pred
func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "tosa.logical_not"(%arg0) : (tensor<2x3xi1>) -> tensor<2x3xi1>
@ -292,8 +236,6 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2:
return %1 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: @reduce_all_fold
func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -301,8 +243,6 @@ func.func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_all_nofold
func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_all"
@ -310,8 +250,6 @@ func.func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_any_fold
func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -319,8 +257,6 @@ func.func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_any_nofold
func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_any"
@ -328,8 +264,6 @@ func.func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_max_fold
func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -337,8 +271,6 @@ func.func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_max_nofold
func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_max"
@ -346,8 +278,6 @@ func.func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_min_fold
func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -355,8 +285,6 @@ func.func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_min_nofold
func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_min"
@ -364,8 +292,6 @@ func.func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_prod_fold
func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -373,8 +299,6 @@ func.func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_prod_nofold
func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_prod"
@ -382,8 +306,6 @@ func.func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_sum_fold
func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg0
@ -391,8 +313,6 @@ func.func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reduce_sum_nofold
func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: "tosa.reduce_sum"
@ -400,8 +320,6 @@ func.func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
return %0 : tensor<?x1xf32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize
func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK: return %arg0
@ -409,8 +327,6 @@ func.func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
return %0 : tensor<?x10xf32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize_double
func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
// CHECK: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = [-1, 5]}
@ -420,8 +336,6 @@ func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf3
return %1 : tensor<?x5xf32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize_const
func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
// CHECK: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<1x10xi32>}
@ -431,8 +345,6 @@ func.func @reshape_canonicalize_const() -> tensor<1x10xi32> {
return %1 : tensor<1x10xi32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize_const_spat
func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32>) {
// CHECK-DAG: %[[VAR0:.+]] = "tosa.const"() {value = dense<0> : tensor<10xi32>}
@ -443,8 +355,6 @@ func.func @reshape_canonicalize_const_spat() -> (tensor<10xi32>, tensor<1x10xi32
return %0 , %1 : tensor<10xi32>, tensor<1x10xi32>
}
// -----
// CHECK-LABEL: @reshape_canonicalize_const_sparse
func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32>) {
//CHECK: "tosa.reshape"
@ -453,8 +363,6 @@ func.func @reshape_canonicalize_const_sparse() -> (tensor<3xi32>, tensor<1x3xi32
return %0 , %1 : tensor<3xi32>, tensor<1x3xi32>
}
// -----
// CHECK-LABEL: @slice_fold
func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@ -462,8 +370,6 @@ func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
return %0 : tensor<3x4xf32>
}
// -----
// CHECK-LABEL: @slice_nofold
func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
// CHECK: "tosa.slice"
@ -471,8 +377,6 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
return %0 : tensor<?x4xf32>
}
// -----
// CHECK-LABEL: @tile_fold
func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@ -480,8 +384,6 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
return %0 : tensor<3x4xf32>
}
// -----
// CHECK-LABEL: @tile_nofold
func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
// CHECK: "tosa.tile"
@ -489,8 +391,6 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
return %0 : tensor<3x8xf32>
}
// -----
// CHECK-LABEL: @transpose_fold
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
@ -499,8 +399,6 @@ func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
return %1 : tensor<3x4xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold
func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
// CHECK: "tosa.transpose"
@ -509,8 +407,6 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
return %1 : tensor<3x3xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_shape
func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
// CHECK: "tosa.transpose"
@ -519,8 +415,6 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
return %1 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_splat
func.func @transpose_fold_splat() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@ -532,8 +426,6 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@ -545,8 +437,6 @@ func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_4d_int
func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
%input = "tosa.const"() {value = dense<[[
@ -565,8 +455,6 @@ func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
return %1 : tensor<3x1x4x2xi32>
}
// -----
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
@ -575,8 +463,6 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@ -585,8 +471,6 @@ func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_multi_users
func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
@ -596,8 +480,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_quantized_types
func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
%perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
@ -607,8 +489,6 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<
return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
}
// -----
// CHECK-LABEL: @transpose_no_op
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK: return %arg0
@ -617,3 +497,12 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
%1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32>
return %1 : tensor<3x4x5x6xf32>
}
// CHECK-LABEL: @single_bit_reshape
// https://github.com/llvm/llvm-project/issues/55440
func.func @single_bit_reshape() -> tensor<1xi1> {
// CHECK: "tosa.const"() {value = dense<true> : tensor<1xi1>}
%0 = arith.constant dense<true> : tensor<1x1xi1>
%1 = "tosa.reshape"(%0) {new_shape = [1]} : (tensor<1x1xi1>) -> tensor<1xi1>
return %1 : tensor<1xi1>
}