Use TestDialect to test traits instead of unittest.

--

PiperOrigin-RevId: 249916947
This commit is contained in:
Jacques Pienaar 2019-05-24 16:17:52 -07:00 committed by Mehdi Amini
parent 9f1f91e770
commit 8b4c214046
4 changed files with 62 additions and 133 deletions

42
mlir/test/IR/traits.mlir Normal file
View File

@ -0,0 +1,42 @@
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
// CHECK: succeededSameOperandAndResultElementType
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "test.same_operand_and_result_type"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32>
%2 = "test.same_operand_and_result_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32>
%3 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%4 = "test.same_operand_and_result_type"(%v1, %t1) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32>
return
}
// -----
func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
// expected-error@+1 {{requires the same element type for all operands and results}}
%0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
}
// -----
func @failedSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
// expected-error@+1 {{requires the same element type for all operands and results}}
%0 = "test.same_operand_and_result_type"(%t1, %t1i) : (tensor<1xf32>, tensor<1xi32>) -> tensor<1xf32>
}
// -----
// CHECK: succeededSameOperandAndResultShape
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
%0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
return
}
// -----
func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %v1: vector<1xf32>) {
// expected-error@+1 {{requires the same shape for all operands and results}}
%0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
}

View File

@ -50,6 +50,8 @@ def VUVFoldTwoResultOp : Pattern<(VUVTwoResultOp $input), [
// Test Types
//===----------------------------------------------------------------------===//
def AnyVectorOrTensor: AnyTypeOf<[AnyVector, AnyTensor]>;
def TupleOp : TEST_Op<"tuple_32_bit"> {
let results = (outs TupleOf<[I32, F32]>);
}
@ -58,4 +60,21 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> {
let results = (outs NestedTupleOf<[I32, F32]>);
}
#endif // TEST_OPS
//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_type",
[SameOperandsAndResultElementType]> {
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
let results = (outs AnyVectorOrTensor:$res);
}
def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
[SameValueShape]> {
let arguments = (ins AnyVectorOrTensor:$x, AnyVectorOrTensor:$y);
let results = (outs AnyVectorOrTensor:$res);
}
#endif // TEST_OPS

View File

@ -1,7 +1,6 @@
add_mlir_unittest(MLIRIRTests
DialectTest.cpp
OperationSupportTest.cpp
OpDefinitionTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE

View File

@ -1,131 +0,0 @@
//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "gmock/gmock.h"
using namespace mlir;
using namespace mlir::OpTrait::impl;
namespace {
#define FILE_LOC \
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
&context)
// TODO: Replace with regular test once this trait is used by operation in core.
// TODO(b/132891206): Replace with dialect test.
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
MLIRContext context;
Builder b(&context);
auto *operandtF32x10x10 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandvF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtI32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
// Verifies whether an op with x and y as inputs and resultType satisfies the
// SameOperandAndResultElementType trait.
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
auto op = Operation::create(loc, OperationName("some_op", &context),
/*operands=*/{x->getResult(0), y->getResult(0)},
/*resultTypes=*/{resultType},
/*attributes=*/llvm::None, /*successors=*/{},
/*numRegions=*/0,
/*resizableOperandList=*/false, &context);
return succeeded(verifySameOperandsAndResultElementType(op));
};
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({12}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getTensorType({5}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
b.getTensorType({7}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getTensorType({12}, b.getIntegerType(32))));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
b.getTensorType({9}, b.getIntegerType(32))));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getVectorType({9}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
b.getVectorType({9}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
b.getTensorType({5}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
b.getTensorType({5}, b.getF32Type())));
}
TEST(OpDefinitionTest, SameOperandAndResultShape) {
MLIRContext context;
Builder b(&context);
auto *operandtF32x10x10 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32xunranked = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType(b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
// SameOperandAndResultShape trait.
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
auto op = Operation::create(loc, OperationName("some_op", &context),
/*operands=*/{x->getResult(0), y->getResult(0)},
/*resultTypes=*/{resultType},
/*attributes=*/llvm::None, /*successors=*/{},
/*numRegions=*/0,
/*resizableOperandList=*/false, &context);
return succeeded(verifySameOperandsAndResultShape(op));
};
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({1}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({12}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10,
b.getTensorType({1}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked,
b.getTensorType({1}, b.getF32Type())));
}
#undef FILE_LOC
} // end namespace