diff --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt index 619f4b124c7c..3d9a7ed36979 100644 --- a/mlir/include/mlir/Analysis/CMakeLists.txt +++ b/mlir/include/mlir/Analysis/CMakeLists.txt @@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS CallInterfaces.td) mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRCallOpInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td) +mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTypeInferOpInterfaceIncGen) diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h new file mode 100644 index 000000000000..b80723e45f13 --- /dev/null +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -0,0 +1,40 @@ +//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains the definitions of the infer op interfaces defined in +// `InferTypeOpInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ +#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +#include "mlir/Analysis/InferTypeOpInterface.h.inc" + +} // namespace mlir + +#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td new file mode 100644 index 000000000000..a155810b081c --- /dev/null +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -0,0 +1,63 @@ +//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains a set of interfaces that can be used to define information +// related to call-like and callable operations. Each of which are defined along +// with the respective interface below. +// +//===----------------------------------------------------------------------===// + +#ifdef MLIR_INFERTYPEOPINTERFACE +#else +#define MLIR_INFERTYPEOPINTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +// OpInterface to compute the return type of an operation. The arguments match +// those in Operation::create with the exception that the location is optional +// (if no location is provided, then the method will not emit an error on +// mismatch). +def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that could be used during op construction, verification or + type inference. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{Returns the return types that an op would generate. + + The method takes an optional location which, if set, will be used to + report errors on. The operands and attributes correspond to those with + which an Operation would be created (e.g., as used in Operation;:create). + Regions are the nested regions of the op. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"inferReturnTypes", + /*args=*/(ins "llvm::Optional":$location, + "ArrayRef":$operands, + "ArrayRef":$attributes, + "ArrayRef":$regions) + >, + ]; +} + +#endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index dff3df9a016f..c16ad3f3d1f2 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -3,6 +3,7 @@ add_llvm_library(MLIRAnalysis STATIC AffineStructures.cpp CallGraph.cpp Dominance.cpp + InferTypeOpInterface.cpp LoopAnalysis.cpp MemRefBoundCheck.cpp NestedMatcher.cpp @@ -20,6 +21,7 @@ add_llvm_library(MLIRAnalysis STATIC add_dependencies(MLIRAnalysis MLIRAffineOps MLIRCallOpInterfacesIncGen + MLIRTypeInferOpInterfaceIncGen MLIRLoopOps MLIRVectorOps ) diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp new file mode 100644 index 000000000000..cbbd44681bac --- /dev/null +++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp @@ -0,0 +1,31 @@ +//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains the definitions of the infer op interfaces defined in +// `InferTypeOpInterface.td`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/InferTypeOpInterface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +#include "mlir/Analysis/InferTypeOpInterface.cpp.inc" +} // namespace mlir diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 69c3bcdb21ed..d91bb1a2f576 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -216,6 +216,14 @@ OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { return operand(); } +SmallVector mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( + llvm::Optional location, ArrayRef operands, + ArrayRef attributes, ArrayRef regions) { + if (location) + mlir::emitError(*location) << "expected to fail"; + return SmallVector{nullptr}; +} + // Static initialization for Test dialect registration. static mlir::DialectRegistration testDialect; diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h index a2fcecab718f..ffe2a1c50ec7 100644 --- a/mlir/test/lib/TestDialect/TestDialect.h +++ b/mlir/test/lib/TestDialect/TestDialect.h @@ -24,6 +24,7 @@ #define MLIR_TESTDIALECT_H #include "mlir/Analysis/CallInterfaces.h" +#include "mlir/Analysis/InferTypeOpInterface.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 0c2e53b2bca4..72991ced497b 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -21,6 +21,7 @@ include "mlir/IR/OpBase.td" include "mlir/Analysis/CallInterfaces.td" +include "mlir/Analysis/InferTypeOpInterface.td" def TEST_Dialect : Dialect { let name = "test"; @@ -318,8 +319,7 @@ def ParentOp : TEST_Op<"parent">; def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>; -def TerminatorOp : TEST_Op<"finish", [Terminator]> { -} +def TerminatorOp : TEST_Op<"finish", [Terminator]>; def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator", [SingleBlockImplicitTerminator<"TerminatorOp">]> { let regions = (region SizedRegion<1>:$region); @@ -329,6 +329,18 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> { let arguments = (ins I32ElementsAttr:$attr); } +def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", + [InferTypeOpInterface]> { + let arguments = (ins AnyTensor:$x, AnyTensor:$y); + let results = (outs AnyTensor:$res); + // TODO(jpienaar): Remove the need to specify these here. + let extraClassDeclaration = [{ + SmallVector inferReturnTypes(llvm::Optional location, + ArrayRef operands, ArrayRef attributes, + ArrayRef regions); + }]; +} + def IsNotScalar : Constraint>; def UpdateAttr : Pat<(I32ElementsAttrOp $attr), diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 533ec1fe0ac0..17a257f6669a 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -51,6 +51,43 @@ struct TestPatternDriver : public FunctionPass { static mlir::PassRegistration pass("test-patterns", "Run test dialect patterns"); +//===----------------------------------------------------------------------===// +// ReturnType Driver. +//===----------------------------------------------------------------------===// + +struct ReturnTypeOpMatch : public RewritePattern { + ReturnTypeOpMatch(MLIRContext *ctx) + : RewritePattern(OpWithInferTypeInterfaceOp::getOperationName(), 1, ctx) { + } + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + if (auto retTypeFn = dyn_cast(op)) { + SmallVector values; + values.reserve(op->getNumOperands()); + for (auto &operand : op->getOpOperands()) + values.push_back(operand.get()); + (void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(), + op->getRegions()); + } + return matchFailure(); + } +}; + +namespace { +struct TestReturnTypeDriver : public FunctionPass { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + populateWithGenerated(&getContext(), &patterns); + patterns.insert(&getContext()); + applyPatternsGreedily(getFunction(), patterns); + } +}; +} // end anonymous namespace + +static mlir::PassRegistration + rt_pass("test-return-type", "Run return type functions"); + //===----------------------------------------------------------------------===// // Legalization Driver. //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir new file mode 100644 index 000000000000..f203677546ee --- /dev/null +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: testReturnTypeOpInterface +func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { + // expected-error@+1 {{expected to fail}} + %0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + return +}