forked from OSchip/llvm-project
[mlir][Tensor] NFC - Add result pretty printing to TensorOps
Differential Revision: https://reviews.llvm.org/D135135
This commit is contained in:
parent
42ad305bdb
commit
54a4e9685d
|
@ -18,6 +18,7 @@ include "mlir/Interfaces/ShapedOpInterfaces.td"
|
|||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/TilingInterface.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
|
||||
class Tensor_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<Tensor_Dialect, mnemonic, traits>;
|
||||
|
@ -46,7 +47,9 @@ class Tensor_OpWithOffsetSizesAndStrides<string mnemonic,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_CastOp : Tensor_Op<"cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary = "tensor cast operation";
|
||||
let description = [{
|
||||
|
@ -82,7 +85,10 @@ def Tensor_CastOp : Tensor_Op<"cast", [
|
|||
// DimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect, ShapedDimOpInterface]> {
|
||||
def Tensor_DimOp : Tensor_Op<"dim", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect,
|
||||
ShapedDimOpInterface]> {
|
||||
let summary = "dimension index operation";
|
||||
let description = [{
|
||||
The `tensor.dim` operation takes a tensor and a dimension operand of type
|
||||
|
@ -199,11 +205,12 @@ def Tensor_EmptyOp : Tensor_Op<"empty",
|
|||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ExtractOp : Tensor_Op<"extract",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"result type matches element type of tensor",
|
||||
"tensor", "result",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
def Tensor_ExtractOp : Tensor_Op<"extract", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect,
|
||||
TypesMatchWith<"result type matches element type of tensor",
|
||||
"tensor", "result",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
let summary = "element extraction operation";
|
||||
let description = [{
|
||||
The `tensor.extract` op reads a tensor and returns one
|
||||
|
@ -242,8 +249,10 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
|
||||
NoSideEffect, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
AttrSizedOperandSegments,
|
||||
NoSideEffect,
|
||||
OffsetSizeAndStrideOpInterface
|
||||
]> {
|
||||
let summary = "extract slice operation";
|
||||
|
@ -436,6 +445,7 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect,
|
||||
TypesMatchWith<"operand types match result element type",
|
||||
"result", "elements", "SmallVector<Type, 2>("
|
||||
|
@ -481,6 +491,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_GatherOp : Tensor_Op<"gather", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary = "gather a subset of a tensor at specified indices";
|
||||
|
@ -618,10 +629,11 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
|
|||
// GenerateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_GenerateOp : Tensor_Op<"generate",
|
||||
[RecursiveSideEffects,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
|
||||
def Tensor_GenerateOp : Tensor_Op<"generate", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
RecursiveSideEffects,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
|
||||
let summary = "Creates a dynamically sized tensor from elements";
|
||||
let description = [{
|
||||
This operation creates a dynamically sized tensor with elements of any type.
|
||||
|
@ -664,14 +676,15 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
|
|||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_InsertOp : Tensor_Op<"insert",
|
||||
[NoSideEffect,
|
||||
TypesMatchWith<"result type matches type of dest",
|
||||
"dest", "result",
|
||||
"$_self.cast<ShapedType>()">,
|
||||
TypesMatchWith<"scalar type matches element type of dest",
|
||||
"dest", "scalar",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
def Tensor_InsertOp : Tensor_Op<"insert", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect,
|
||||
TypesMatchWith<"result type matches type of dest",
|
||||
"dest", "result",
|
||||
"$_self.cast<ShapedType>()">,
|
||||
TypesMatchWith<"scalar type matches element type of dest",
|
||||
"dest", "scalar",
|
||||
"$_self.cast<ShapedType>().getElementType()">]> {
|
||||
let summary = "element insertion operation";
|
||||
let description = [{
|
||||
The `tensor.insert` op writes a tensor into a tensor `dest`as specified by
|
||||
|
@ -717,8 +730,11 @@ def Tensor_InsertOp : Tensor_Op<"insert",
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
||||
NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
AttrSizedOperandSegments,
|
||||
NoSideEffect,
|
||||
OffsetSizeAndStrideOpInterface,
|
||||
TypesMatchWith<"expected result type to match dest type",
|
||||
"dest", "result", "$_self">
|
||||
]> {
|
||||
|
@ -854,7 +870,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
|||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
|
||||
def Tensor_RankOp : Tensor_Op<"rank", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect]> {
|
||||
let summary = "rank operation";
|
||||
let description = [{
|
||||
The `tensor.rank` operation takes a tensor operand and returns its rank.
|
||||
|
@ -878,7 +896,9 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
|
|||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
|
||||
def Tensor_ReshapeOp: Tensor_Op<"reshape", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect]> {
|
||||
let summary = "tensor reshape operation";
|
||||
let description = [{
|
||||
The `reshape` operation converts a tensor from one type to an equivalent
|
||||
|
@ -941,7 +961,9 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
|
||||
Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
|
||||
Tensor_Op<mnemonic, !listconcat(traits, [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect])>,
|
||||
Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
|
||||
Results<(outs AnyTensor:$result)> {
|
||||
|
||||
|
@ -1091,7 +1113,10 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
|
|||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
||||
def Tensor_PadOp : Tensor_Op<"pad", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AttrSizedOperandSegments,
|
||||
NoSideEffect,
|
||||
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
|
||||
let summary = "tensor pad operation";
|
||||
let description = [{
|
||||
|
@ -1433,6 +1458,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_ScatterOp : Tensor_Op<"scatter", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary =
|
||||
|
@ -1573,6 +1599,7 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tensor_SplatOp : Tensor_Op<"splat", [
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
NoSideEffect,
|
||||
TypesMatchWith<"operand type matches element type of result",
|
||||
"aggregate", "input",
|
||||
|
|
|
@ -58,6 +58,10 @@ SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
|
|||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "cast");
|
||||
}
|
||||
|
||||
/// Returns true if `target` is a ranked tensor type that preserves static
|
||||
/// information available in the `source` ranked tensor type.
|
||||
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
|
||||
|
@ -307,6 +311,10 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
// DimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "dim");
|
||||
}
|
||||
|
||||
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
int64_t index) {
|
||||
auto loc = result.location;
|
||||
|
@ -697,6 +705,11 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
// ExtractOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ExtractOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "extracted");
|
||||
}
|
||||
|
||||
LogicalResult ExtractOp::verify() {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
|
||||
|
@ -756,6 +769,11 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
|||
// FromElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void FromElementsOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "from_elements");
|
||||
}
|
||||
|
||||
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||
Type resultType, ValueRange elements) {
|
||||
result.addOperands(elements);
|
||||
|
@ -828,6 +846,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
// GatherOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void GatherOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "gather");
|
||||
}
|
||||
|
||||
/// Return the inferred result type for a gatherOp where:
|
||||
/// - sourceType is the type of the source tensor gathered from
|
||||
/// - indicesType is the type of the indices used to gather
|
||||
|
@ -911,6 +934,11 @@ LogicalResult GatherOp::verify() {
|
|||
// InsertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "inserted");
|
||||
}
|
||||
|
||||
LogicalResult InsertOp::verify() {
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
|
||||
|
@ -933,6 +961,11 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
|
|||
// GenerateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void GenerateOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "generated");
|
||||
}
|
||||
|
||||
LogicalResult GenerateOp::reifyResultShapes(
|
||||
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
||||
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
||||
|
@ -1116,6 +1149,10 @@ void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "rank");
|
||||
}
|
||||
|
||||
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold rank when the rank of the operand is known.
|
||||
auto type = getOperand().getType();
|
||||
|
@ -1129,6 +1166,11 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
|||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ReshapeOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "reshape");
|
||||
}
|
||||
|
||||
static int64_t getNumElements(ShapedType type) {
|
||||
int64_t numElements = 1;
|
||||
for (auto dim : type.getShape())
|
||||
|
@ -1170,6 +1212,16 @@ LogicalResult ReshapeOp::verify() {
|
|||
// Reassociative reshape ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CollapseShapeOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "collapsed");
|
||||
}
|
||||
|
||||
void ExpandShapeOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "expanded");
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
|
||||
return getSymbolLessAffineMaps(getReassociationExprs());
|
||||
}
|
||||
|
@ -1369,6 +1421,11 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// ExtractSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ExtractSliceOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "extracted_slice");
|
||||
}
|
||||
|
||||
/// An extract_slice result type can be inferred, when it is not
|
||||
/// rank-reduced, from the source type and the static representation of
|
||||
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
|
||||
|
@ -1865,6 +1922,11 @@ Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
|
|||
// InsertSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void InsertSliceOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "inserted_slice");
|
||||
}
|
||||
|
||||
// Build a InsertSliceOp with mixed static and dynamic entries.
|
||||
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
|
||||
Value dest, ArrayRef<OpFoldResult> offsets,
|
||||
|
@ -2218,6 +2280,10 @@ Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
|
|||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "padded");
|
||||
}
|
||||
|
||||
// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
|
||||
// supports optional types.
|
||||
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
|
||||
|
@ -2725,6 +2791,11 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
|
|||
// ScatterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ScatterOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "scatter");
|
||||
}
|
||||
|
||||
LogicalResult ScatterOp::verify() {
|
||||
int64_t destRank = getDestType().getRank();
|
||||
ArrayRef<int64_t> scatterDims = getScatterDims();
|
||||
|
@ -2761,6 +2832,11 @@ LogicalResult ScatterOp::verify() {
|
|||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SplatOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "splat");
|
||||
}
|
||||
|
||||
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = operands.front();
|
||||
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
|
||||
|
|
|
@ -595,7 +595,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
|
|||
// CHECK: ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
|
||||
// CHECK: tensor.yield %cst : f32
|
||||
// CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
|
||||
// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
|
||||
// CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
|
||||
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]]
|
||||
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
|
||||
return
|
||||
|
|
|
@ -604,8 +604,8 @@ func.func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6
|
|||
|
||||
// CHECK-LABEL: @test_reshape_downrank_6D_dyn
|
||||
func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
|
||||
// CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]]
|
||||
// CHECK: tensor.expand_shape %0 {{\[}}[0, 1, 2]]
|
||||
// CHECK: tensor.collapse_shape {{.*}}[0, 1, 2, 3, 4, 5]
|
||||
// CHECK: tensor.expand_shape {{.*}}[0, 1, 2]
|
||||
%0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
|
||||
return %0 : tensor<?x5x77xf32>
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func.func @slice_dyn(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
|
|||
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
|
||||
// CHECK: %[[C2:.+]] = arith.constant 2 : index
|
||||
// CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]]
|
||||
// CHECK: %2 = tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
|
||||
// CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1]
|
||||
%0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor<?xf32>) -> (tensor<?xf32>)
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
|
|
@ -5,13 +5,13 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
|
|||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
|
||||
// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) {
|
||||
// CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}}, %{{[a-zA-Z0-9_]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>)
|
||||
// CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>) {
|
||||
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
|
||||
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>)
|
||||
// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) {
|
||||
// CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>)
|
||||
// CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32xf32>) {
|
||||
%0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
|
||||
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
|
||||
return %0: tensor<16x32xf32>
|
||||
|
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
// CHECK-LABEL: func @cast(
|
||||
func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?xf32>) {
|
||||
// CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<*xf32> to tensor<?x?xf32>
|
||||
%0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
|
||||
// CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<4x4xf32> to tensor<*xf32>
|
||||
%1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32>
|
||||
// CHECK: tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
%2 = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
|
||||
// CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
%3 = tensor.cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue