[mlir] Move trait to InferTypeOpInterface

Step towards removing the hard coded behavior for this trait and to instead use common interface.

Differential Revision: https://reviews.llvm.org/D114208
This commit is contained in:
Jacques Pienaar 2021-11-21 14:41:11 -08:00
parent c133fb321f
commit 6f9cceb775
29 changed files with 55 additions and 48 deletions

View File

@ -17,6 +17,7 @@
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

View File

@ -10,6 +10,7 @@
#define STANDALONE_OPS
include "Standalone/StandaloneDialect.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def Standalone_FooOp : Standalone_Op<"foo", [NoSideEffect,

View File

@ -11,6 +11,7 @@
include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

View File

@ -10,6 +10,7 @@
#define COMPLEX_OPS
include "mlir/Dialect/Complex/IR/ComplexBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Complex_Op<string mnemonic, list<OpTrait> traits = []>
@ -143,10 +144,6 @@ def EqualOp : Complex_Op<"eq",
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
let results = (outs I1:$result);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
build($_builder, $_state, $_builder.getI1Type(), lhs, rhs);
}]>];
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
}
@ -292,10 +289,6 @@ def NotEqualOp : Complex_Op<"neq",
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
let results = (outs I1:$result);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
build($_builder, $_state, $_builder.getI1Type(), lhs, rhs);
}]>];
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
}

View File

@ -23,6 +23,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {

View File

@ -18,6 +18,7 @@ include "mlir/Dialect/GPU/GPUBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//

View File

@ -22,6 +22,7 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"

View File

@ -17,6 +17,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def FMFnnan : BitEnumAttrCase<"nnan", 0x1>;
@ -620,11 +621,6 @@ def LLVM_SelectOp
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
LLVM_Type:$trueValue, LLVM_Type:$falseValue);
let results = (outs LLVM_Type:$res);
let builders = [
OpBuilder<(ins "Value":$condition, "Value":$lhs, "Value":$rhs),
[{
build($_builder, $_state, lhs.getType(), condition, lhs, rhs);
}]>];
let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)";
}
def LLVM_FreezeOp : LLVM_Op<"freeze", [SameOperandsAndResultType]> {

View File

@ -10,6 +10,7 @@
#define MATH_OPS
include "mlir/Dialect/Math/IR/MathBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

View File

@ -14,6 +14,7 @@
#define DIALECT_QUANT_QUANT_OPS_
include "mlir/Dialect/Quant/QuantOpsBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,7 @@
#define MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class SPV_ArithmeticBinaryOp<string mnemonic, Type type,

View File

@ -1000,9 +1000,6 @@ def SPV_SelectOp : SPV_Op<"Select",
SPV_SelectType:$result
);
let builders = [
OpBuilder<(ins "Value":$cond, "Value":$trueValue, "Value":$falseValue)>];
let assemblyFormat = [{
operands attr-dict `:` type($condition) `,` type($result)
}];

View File

@ -208,8 +208,6 @@ def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
SPV_Bool:$result
);
let builders = [OpBuilder<(ins "spirv::Scope")>];
let assemblyFormat = "$execution_scope attr-dict `:` type($result)";
}

View File

@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

View File

@ -11,6 +11,7 @@
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//

View File

@ -22,6 +22,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"

View File

@ -19,6 +19,7 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
@ -687,12 +688,6 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
let results = (outs Index);
let verifier = ?;
let builders = [
OpBuilder<(ins "Value":$tensor), [{
auto indexType = $_builder.getIndexType();
build($_builder, $_state, indexType, tensor);
}]>];
let hasFolder = 1;
let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
}
@ -775,13 +770,6 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
AnyType:$false_value);
let results = (outs AnyType:$result);
let builders = [
OpBuilder<(ins "Value":$condition, "Value":$trueValue,
"Value":$falseValue), [{
$_state.addOperands({condition, trueValue, falseValue});
$_state.addTypes(trueValue.getType());
}]>];
let hasCanonicalizer = 1;
let hasFolder = 1;
}

View File

@ -13,6 +13,7 @@
#ifndef X86VECTOR_OPS
#define X86VECTOR_OPS
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"

View File

@ -1952,8 +1952,6 @@ def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same operand and result type.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).

View File

@ -178,4 +178,8 @@ def ReifyRankedShapedTypeOpInterface :
];
}
// Op has the same operand and result type.
// TODO: Change from hard coded to utilizing type inference trait.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
#endif // MLIR_INFERTYPEOPINTERFACE

View File

@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLLVMIR
MLIRCallInterfaces
MLIRControlFlowInterfaces
MLIRDataLayoutInterfaces
MLIRInferTypeOpInterface
MLIRIR
MLIRSideEffectInterfaces
MLIRSupport

View File

@ -2395,12 +2395,6 @@ static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
// spv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//
void spirv::GroupNonUniformElectOp::build(OpBuilder &builder,
OperationState &state,
spirv::Scope scope) {
build(builder, state, builder.getI1Type(), scope);
}
static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
spirv::Scope scope = groupOp.execution_scope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
@ -2849,11 +2843,6 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
// spv.Select
//===----------------------------------------------------------------------===//
void spirv::SelectOp::build(OpBuilder &builder, OperationState &state,
Value cond, Value trueValue, Value falseValue) {
build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
}
static LogicalResult verify(spirv::SelectOp op) {
if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();

View File

@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRStandard
MLIRCallInterfaces
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRInferTypeOpInterface
MLIRIR
MLIRSideEffectInterfaces
MLIRVectorInterfaces

View File

@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
using namespace mlir;

View File

@ -25,7 +25,11 @@ func @bit_field_insert_vec(%base: vector<3xi32>, %insert: vector<3xi32>, %offset
// -----
func @bit_field_insert_invalid_insert_type(%base: vector<3xi32>, %insert: vector<2xi32>, %offset: i32, %count: i16) -> vector<3xi32> {
// expected-error @+1 {{all of {base, insert, result} have same type}}
// TODO: expand post change in verification order. This is currently only
// verifying that the type verification is failing but not the specific error
// message. In final state the error should refer to mismatch in base and
// insert.
// expected-error @+1 {{type}}
%0 = "spv.BitFieldInsert" (%base, %insert, %offset, %count) : (vector<3xi32>, vector<2xi32>, i32, i16) -> vector<3xi32>
spv.ReturnValue %0 : vector<3xi32>
}
@ -55,7 +59,7 @@ func @bit_field_u_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) ->
// -----
func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> {
// expected-error @+1 {{failed to verify that all of {base, result} have same type}}
// expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}}
%0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32>
spv.ReturnValue %0 : vector<4xi32>
}

View File

@ -270,7 +270,11 @@ func @select_op(%arg1: vector<4xi1>) -> () {
func @select_op(%arg1: vector<4xi1>) -> () {
%0 = spv.Constant dense<[2.0, 3.0, 4.0]> : vector<3xf32>
%1 = spv.Constant dense<[5, 6, 7]> : vector<3xi32>
// expected-error @+1 {{all of {true_value, false_value, result} have same type}}
// TODO: expand post change in verification order. This is currently only
// verifying that the type verification is failing but not the specific error
// message. In final state the error should refer to mismatch in true_value and
// false_value.
// expected-error @+1 {{type}}
%2 = "spv.Select"(%arg1, %1, %0) : (vector<4xi1>, vector<3xi32>, vector<3xf32>) -> vector<3xi32>
return
}

View File

@ -137,7 +137,11 @@ func @func_with_ops(i32, i32, i32) {
func @func_with_ops(i1, i32, i64) {
^bb0(%cond : i1, %t : i32, %f : i64):
// expected-error@+1 {{all of {true_value, false_value, result} have same type}}
// TODO: expand post change in verification order. This is currently only
// verifying that the type verification is failing but not the specific error
// message. In final state the error should refer to mismatch in true_value and
// false_value.
// expected-error@+1 {{type}}
%r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32
}

View File

@ -1500,6 +1500,7 @@ td_library(
srcs = ["include/mlir/Dialect/X86Vector/X86Vector.td"],
includes = ["include"],
deps = [
":InferTypeOpInterfaceTdFiles",
":LLVMOpsTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -1548,6 +1549,7 @@ cc_library(
includes = ["include"],
deps = [
":IR",
":InferTypeOpInterface",
":LLVMDialect",
":SideEffectInterfaces",
":X86VectorIncGen",
@ -1688,6 +1690,7 @@ td_library(
],
includes = ["include"],
deps = [
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -1788,6 +1791,7 @@ cc_library(
deps = [
":ArithmeticDialect",
":IR",
":InferTypeOpInterface",
":SideEffectInterfaces",
":SparseTensorAttrDefsIncGen",
":SparseTensorOpsIncGen",
@ -1856,6 +1860,7 @@ td_library(
":CallInterfacesTdFiles",
":CastInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":VectorInterfacesTdFiles",
@ -2519,6 +2524,7 @@ cc_library(
":CommonFolders",
":ControlFlowInterfaces",
":IR",
":InferTypeOpInterface",
":SideEffectInterfaces",
":StandardOpsIncGen",
":Support",
@ -2750,6 +2756,7 @@ cc_library(
":ControlFlowInterfaces",
":DataLayoutInterfaces",
":IR",
":InferTypeOpInterface",
":LLVMDialectAttributesIncGen",
":LLVMDialectInterfaceIncGen",
":LLVMOpsIncGen",
@ -2915,6 +2922,7 @@ cc_library(
":GPUBaseIncGen",
":GPUOpsIncGen",
":IR",
":InferTypeOpInterface",
":LLVMDialect",
":MemRefDialect",
":SideEffectInterfaces",
@ -3011,6 +3019,7 @@ td_library(
includes = ["include"],
deps = [
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -3669,6 +3678,7 @@ td_library(
deps = [
":CallInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -3819,6 +3829,7 @@ cc_library(
":CommonFolders",
":ControlFlowInterfaces",
":IR",
":InferTypeOpInterface",
":Parser",
":Pass",
":SPIRVAttrUtilsGen",
@ -6037,6 +6048,7 @@ td_library(
],
includes = ["include"],
deps = [
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -6939,6 +6951,7 @@ td_library(
],
includes = ["include"],
deps = [
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
],
@ -7073,6 +7086,7 @@ td_library(
includes = ["include"],
deps = [
":CastInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":VectorInterfacesTdFiles",
@ -7218,6 +7232,7 @@ td_library(
],
includes = ["include"],
deps = [
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":VectorInterfacesTdFiles",