forked from OSchip/llvm-project
Use TestDialect to test traits instead of unittest.
-- PiperOrigin-RevId: 249916947
This commit is contained in:
parent
9f1f91e770
commit
8b4c214046
|
@ -0,0 +1,42 @@
|
|||
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
|
||||
|
||||
// CHECK: succeededSameOperandAndResultElementType
|
||||
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
|
||||
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "test.same_operand_and_result_type"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32>
|
||||
%2 = "test.same_operand_and_result_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32>
|
||||
%3 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
|
||||
// expected-error@+1 {{requires the same element type for all operands and results}}
|
||||
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
|
||||
// expected-error@+1 {{requires the same element type for all operands and results}}
|
||||
%0 = "test.same_operand_and_result_type"(%t1, %t1i) : (tensor<1xf32>, tensor<1xi32>) -> tensor<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: succeededSameOperandAndResultShape
|
||||
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
|
||||
%0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
%2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>) {
|
||||
// expected-error@+1 {{requires the same shape for all operands and results}}
|
||||
%0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||
}
|
|
@ -50,6 +50,8 @@ def VUVFoldTwoResultOp : Pattern<(VUVTwoResultOp $input), [
|
|||
// Test Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AnyVectorOrTensor: AnyTypeOf<[AnyVector, AnyTensor]>;
|
||||
|
||||
def TupleOp : TEST_Op<"tuple_32_bit"> {
|
||||
let results = (outs TupleOf<[I32, F32]>);
|
||||
}
|
||||
|
@ -58,4 +60,21 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
|
|||
let results = (outs NestedTupleOf<[I32, F32]>);
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
|
||||
[SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
|
||||
let results = (outs AnyVectorOrTensor:$res);
|
||||
}
|
||||
|
||||
def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
|
||||
[SameValueShape]> {
|
||||
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
|
||||
let results = (outs AnyVectorOrTensor:$res);
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
add_mlir_unittest(MLIRIRTests
|
||||
DialectTest.cpp
|
||||
OperationSupportTest.cpp
|
||||
OpDefinitionTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRIRTests
|
||||
PRIVATE
|
||||
|
|
|
@ -1,131 +0,0 @@
|
|||
//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::OpTrait::impl;
|
||||
|
||||
namespace {
|
||||
|
||||
#define FILE_LOC \
|
||||
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
|
||||
&context)
|
||||
|
||||
// TODO: Replace with regular test once this trait is used by operation in core.
|
||||
// TODO(b/132891206): Replace with dialect test.
|
||||
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
auto *operandtF32x10x10 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandvF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtI32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
|
||||
// Verifies whether an op with x and y as inputs and resultType satisfies the
|
||||
// SameOperandAndResultElementType trait.
|
||||
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
|
||||
auto op = Operation::create(loc, OperationName("some_op", &context),
|
||||
/*operands=*/{x->getResult(0), y->getResult(0)},
|
||||
/*resultTypes=*/{resultType},
|
||||
/*attributes=*/llvm::None, /*successors=*/{},
|
||||
/*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
return succeeded(verifySameOperandsAndResultElementType(op));
|
||||
};
|
||||
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({12}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
|
||||
b.getTensorType({7}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getTensorType({12}, b.getIntegerType(32))));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
|
||||
b.getTensorType({9}, b.getIntegerType(32))));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getVectorType({9}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
|
||||
b.getVectorType({9}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
}
|
||||
|
||||
TEST(OpDefinitionTest, SameOperandAndResultShape) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
auto *operandtF32x10x10 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32xunranked = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType(b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
|
||||
// SameOperandAndResultShape trait.
|
||||
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
|
||||
auto op = Operation::create(loc, OperationName("some_op", &context),
|
||||
/*operands=*/{x->getResult(0), y->getResult(0)},
|
||||
/*resultTypes=*/{resultType},
|
||||
/*attributes=*/llvm::None, /*successors=*/{},
|
||||
/*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
return succeeded(verifySameOperandsAndResultShape(op));
|
||||
};
|
||||
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({12}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
}
|
||||
|
||||
#undef FILE_LOC
|
||||
} // end namespace
|
Loading…
Reference in New Issue