Add InferTypeOpTrait & enable generating its member function definition

Use OpInterfaces to add an interface for ops defining a return type function.

This change does not use this trait in any meaningful way, I'll use it in a
follow up to generalize and unify some of the op type traits/constraints. Also,
currently the infer type function can only be manually specified in C++, that should rather be the fallback in future.

PiperOrigin-RevId: 271883746
This commit is contained in:
Jacques Pienaar 2019-09-29 17:28:29 -07:00 committed by A. Unique TensorFlower
parent f45a392566
commit e5a43186d3
10 changed files with 209 additions and 2 deletions

View File

@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS CallInterfaces.td)
mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen)

View File

@ -0,0 +1,40 @@
//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
} // namespace mlir
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_

View File

@ -0,0 +1,63 @@
//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains a set of interfaces that can be used to define information
// related to call-like and callable operations. Each of which are defined along
// with the respective interface below.
//
//===----------------------------------------------------------------------===//
#ifdef MLIR_INFERTYPEOPINTERFACE
#else
#define MLIR_INFERTYPEOPINTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// OpInterface to compute the return type of an operation. The arguments match
// those in Operation::create with the exception that the location is optional
// (if no location is provided, then the method will not emit an error on
// mismatch).
def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that could be used during op construction, verification or
type inference.
}];
let methods = [
InterfaceMethod<
/*desc=*/[{Returns the return types that an op would generate.
The method takes an optional location which, if set, will be used to
report errors on. The operands and attributes correspond to those with
which an Operation would be created (e.g., as used in Operation;:create).
Regions are the nested regions of the op.
}],
/*retTy=*/"SmallVector<Type, 2>",
/*methodName=*/"inferReturnTypes",
/*args=*/(ins "llvm::Optional<Location>":$location,
"ArrayRef<Value*>":$operands,
"ArrayRef<NamedAttribute>":$attributes,
"ArrayRef<Region>":$regions)
>,
];
}
#endif // MLIR_INFERTYPEOPINTERFACE

View File

@ -3,6 +3,7 @@ add_llvm_library(MLIRAnalysis STATIC
AffineStructures.cpp
CallGraph.cpp
Dominance.cpp
InferTypeOpInterface.cpp
LoopAnalysis.cpp
MemRefBoundCheck.cpp
NestedMatcher.cpp
@ -20,6 +21,7 @@ add_llvm_library(MLIRAnalysis STATIC
add_dependencies(MLIRAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
MLIRVectorOps
)

View File

@ -0,0 +1,31 @@
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/InferTypeOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
} // namespace mlir

View File

@ -216,6 +216,14 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
return operand();
}
SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
llvm::Optional<Location> location, ArrayRef<Value *> operands,
ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
if (location)
mlir::emitError(*location) << "expected to fail";
return SmallVector<Type, 2>{nullptr};
}
// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;

View File

@ -24,6 +24,7 @@
#define MLIR_TESTDIALECT_H
#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/Analysis/InferTypeOpInterface.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"

View File

@ -21,6 +21,7 @@
include "mlir/IR/OpBase.td"
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/InferTypeOpInterface.td"
def TEST_Dialect : Dialect {
let name = "test";
@ -318,8 +319,7 @@ def ParentOp : TEST_Op<"parent">;
def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
def TerminatorOp : TEST_Op<"finish", [Terminator]> {
}
def TerminatorOp : TEST_Op<"finish", [Terminator]>;
def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
[SingleBlockImplicitTerminator<"TerminatorOp">]> {
let regions = (region SizedRegion<1>:$region);
@ -329,6 +329,18 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
let arguments = (ins I32ElementsAttr:$attr);
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
[InferTypeOpInterface]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor:$res);
// TODO(jpienaar): Remove the need to specify these here.
let extraClassDeclaration = [{
SmallVector<Type, 2> inferReturnTypes(llvm::Optional<Location> location,
ArrayRef<Value*> operands, ArrayRef<NamedAttribute> attributes,
ArrayRef<Region> regions);
}];
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),

View File

@ -51,6 +51,43 @@ struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
static mlir::PassRegistration<TestPatternDriver>
pass("test-patterns", "Run test dialect patterns");
//===----------------------------------------------------------------------===//
// ReturnType Driver.
//===----------------------------------------------------------------------===//
struct ReturnTypeOpMatch : public RewritePattern {
ReturnTypeOpMatch(MLIRContext *ctx)
: RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) {
}
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
SmallVector<Value *, 4> values;
values.reserve(op->getNumOperands());
for (auto &operand : op->getOpOperands())
values.push_back(operand.get());
(void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
op->getRegions());
}
return matchFailure();
}
};
namespace {
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
patterns.insert<ReturnTypeOpMatch>(&getContext());
applyPatternsGreedily(getFunction(), patterns);
}
};
} // end anonymous namespace
static mlir::PassRegistration<TestReturnTypeDriver>
rt_pass("test-return-type", "Run return type functions");
//===----------------------------------------------------------------------===//
// Legalization Driver.
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,8 @@
// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+1 {{expected to fail}}
%0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
}