forked from OSchip/llvm-project
Add InferTypeOpTrait & enable generating its member function definition
Use OpInterfaces to add an interface for ops defining a return type function. This change does not use this trait in any meaningful way, I'll use it in a follow up to generalize and unify some of the op type traits/constraints. Also, currently the infer type function can only be manually specified in C++, that should rather be the fallback in future. PiperOrigin-RevId: 271883746
This commit is contained in:
parent
f45a392566
commit
e5a43186d3
|
@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS CallInterfaces.td)
|
|||
mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
|
||||
mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the definitions of the infer op interfaces defined in
|
||||
// `InferTypeOpInterface.td`.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
|
||||
#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
|
|
@ -0,0 +1,63 @@
|
|||
//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains a set of interfaces that can be used to define information
|
||||
// related to call-like and callable operations. Each of which are defined along
|
||||
// with the respective interface below.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef MLIR_INFERTYPEOPINTERFACE
|
||||
#else
|
||||
#define MLIR_INFERTYPEOPINTERFACE
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
// OpInterface to compute the return type of an operation. The arguments match
|
||||
// those in Operation::create with the exception that the location is optional
|
||||
// (if no location is provided, then the method will not emit an error on
|
||||
// mismatch).
|
||||
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
||||
let description = [{
|
||||
Interface to access a registered method to infer the return types for an
|
||||
operation that could be used during op construction, verification or
|
||||
type inference.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{Returns the return types that an op would generate.
|
||||
|
||||
The method takes an optional location which, if set, will be used to
|
||||
report errors on. The operands and attributes correspond to those with
|
||||
which an Operation would be created (e.g., as used in Operation;:create).
|
||||
Regions are the nested regions of the op.
|
||||
}],
|
||||
/*retTy=*/"SmallVector<Type, 2>",
|
||||
/*methodName=*/"inferReturnTypes",
|
||||
/*args=*/(ins "llvm::Optional<Location>":$location,
|
||||
"ArrayRef<Value*>":$operands,
|
||||
"ArrayRef<NamedAttribute>":$attributes,
|
||||
"ArrayRef<Region>":$regions)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_INFERTYPEOPINTERFACE
|
|
@ -3,6 +3,7 @@ add_llvm_library(MLIRAnalysis STATIC
|
|||
AffineStructures.cpp
|
||||
CallGraph.cpp
|
||||
Dominance.cpp
|
||||
InferTypeOpInterface.cpp
|
||||
LoopAnalysis.cpp
|
||||
MemRefBoundCheck.cpp
|
||||
NestedMatcher.cpp
|
||||
|
@ -20,6 +21,7 @@ add_llvm_library(MLIRAnalysis STATIC
|
|||
add_dependencies(MLIRAnalysis
|
||||
MLIRAffineOps
|
||||
MLIRCallOpInterfacesIncGen
|
||||
MLIRTypeInferOpInterfaceIncGen
|
||||
MLIRLoopOps
|
||||
MLIRVectorOps
|
||||
)
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the definitions of the infer op interfaces defined in
|
||||
// `InferTypeOpInterface.td`.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
|
||||
} // namespace mlir
|
|
@ -216,6 +216,14 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
|
|||
return operand();
|
||||
}
|
||||
|
||||
SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
|
||||
llvm::Optional<Location> location, ArrayRef<Value *> operands,
|
||||
ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
|
||||
if (location)
|
||||
mlir::emitError(*location) << "expected to fail";
|
||||
return SmallVector<Type, 2>{nullptr};
|
||||
}
|
||||
|
||||
// Static initialization for Test dialect registration.
|
||||
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#define MLIR_TESTDIALECT_H
|
||||
|
||||
#include "mlir/Analysis/CallInterfaces.h"
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Analysis/CallInterfaces.td"
|
||||
include "mlir/Analysis/InferTypeOpInterface.td"
|
||||
|
||||
def TEST_Dialect : Dialect {
|
||||
let name = "test";
|
||||
|
@ -318,8 +319,7 @@ def ParentOp : TEST_Op<"parent">;
|
|||
def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
|
||||
|
||||
|
||||
def TerminatorOp : TEST_Op<"finish", [Terminator]> {
|
||||
}
|
||||
def TerminatorOp : TEST_Op<"finish", [Terminator]>;
|
||||
def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
|
||||
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
@ -329,6 +329,18 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
|
|||
let arguments = (ins I32ElementsAttr:$attr);
|
||||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
|
||||
[InferTypeOpInterface]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
let results = (outs AnyTensor:$res);
|
||||
// TODO(jpienaar): Remove the need to specify these here.
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
|
||||
ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
|
||||
ArrayRef<Region> regions);
|
||||
}];
|
||||
}
|
||||
|
||||
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
|
||||
|
||||
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
|
||||
|
|
|
@ -51,6 +51,43 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
|
|||
static mlir::PassRegistration<TestPatternDriver>
|
||||
pass("test-patterns", "Run test dialect patterns");
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnType Driver.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ReturnTypeOpMatch : public RewritePattern {
|
||||
ReturnTypeOpMatch(MLIRContext *ctx)
|
||||
: RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
|
||||
}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
|
||||
SmallVector<Value *, 4> values;
|
||||
values.reserve(op->getNumOperands());
|
||||
for (auto &operand : op->getOpOperands())
|
||||
values.push_back(operand.get());
|
||||
(void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
|
||||
op->getRegions());
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
|
||||
void runOnFunction() override {
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.insert<ReturnTypeOpMatch>(&getContext());
|
||||
applyPatternsGreedily(getFunction(), patterns);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
static mlir::PassRegistration<TestReturnTypeDriver>
|
||||
rt_pass("test-return-type", "Run return type functions");
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Legalization Driver.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: testReturnTypeOpInterface
|
||||
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||
// expected-error@+1 {{expected to fail}}
|
||||
%0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue