From ab79c25d646ed7ef214b19042d49f15425c49818 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 16 Oct 2019 12:33:55 -0700 Subject: [PATCH] Code cleanups on Ch.4 This change performs general cleanups of the implementation of ch.4 and fixes some bugs. For example, the operations currently don't inherit from the shape inference interface. PiperOrigin-RevId: 275089914 --- mlir/examples/toy/Ch4/include/toy/Dialect.h | 1 + mlir/examples/toy/Ch4/include/toy/Ops.td | 51 ++------------- .../Ch4/include/toy/ShapeInferenceInterface.h | 37 +++++++++++ .../include/toy/ShapeInferenceInterface.td | 4 +- .../Ch4/mlir/DeadFunctionEliminationPass.cpp | 21 ++++--- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 36 +++++++++++ .../toy/Ch4/mlir/ShapeInferencePass.cpp | 63 ++++++++----------- 7 files changed, 121 insertions(+), 92 deletions(-) create mode 100644 mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h index da61191c6c0f..556ae972b84e 100644 --- a/mlir/examples/toy/Ch4/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h @@ -26,6 +26,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/StandardTypes.h" +#include "toy/ShapeInferenceInterface.h" namespace mlir { namespace toy { diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index f0140d70f9bd..a8c67592874d 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -92,7 +92,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { let verifier = [{ return ::verify(*this); }]; } -def AddOp : Toy_Op<"add", [NoSideEffect]> { +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "element-wise addition operation"; let description = [{ The "add" operation performs element-wise addition between two tensors. @@ -108,12 +109,6 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> { buildAddOp(b, result, lhs, rhs); }] >]; - let extraClassDeclaration = [{ - void inferShapes() { - getResult()->setType(getOperand(0)->getType()); - return; - } - }]; } def GenericCallOp : Toy_Op<"generic_call"> { @@ -150,7 +145,8 @@ def GenericCallOp : Toy_Op<"generic_call"> { ]; } -def MulOp : Toy_Op<"mul", [NoSideEffect]> { +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "element-wise multiplication operation"; let description = [{ The "mul" operation performs element-wise multiplication between two @@ -166,30 +162,6 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> { buildMulOp(b, result, lhs, rhs); }] >]; - let extraClassDeclaration = [{ - void inferShapes() { - auto lhs = getOperand(0)->getType().cast(); - auto rhs = getOperand(1)->getType().cast(); - auto lhsRank = lhs.getShape().size(); - auto rhsRank = rhs.getShape().size(); - if (lhsRank != rhsRank) { - return; - } - SmallVector dims; - if (lhsRank == 1) { - // dot product, result shape is <1> - dims.push_back(1); - } else { - if (lhsRank != 2) { - return; - } - dims.push_back(lhs.getShape()[0]); - dims.push_back(rhs.getShape()[1]); - } - getResult()->setType(RankedTensorType::get(dims, lhs.getElementType())); - return; - } - }]; } def PrintOp : Toy_Op<"print"> { @@ -255,7 +227,8 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { let verifier = [{ return ::verify(*this); }]; } -def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "transpose operation"; let arguments = (ins F64Tensor:$input); @@ -268,18 +241,6 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { buildTransposeOp(b, result, input); }] >]; - let extraClassDeclaration = [{ - void inferShapes() { - SmallVector dims; - auto arrayTy = getOperand()->getType().cast(); - dims.insert(dims.end(), arrayTy.getShape().begin(), - arrayTy.getShape().end()); - if (dims.size() == 2) - std::swap(dims[0], dims[1]); - getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); - return; - } - }]; } #endif // TOY_OPS diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h new file mode 100644 index 000000000000..fc36b5b100dd --- /dev/null +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,37 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td index 2040cc44fdf4..4b1240d28d57 100644 --- a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td +++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td @@ -30,8 +30,8 @@ include "mlir/IR/OpBase.td" def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { let methods = [ - InterfaceMethod<"Infer output shape for the current operation.", - "void", "inferShapes", (ins), [{}]> + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> ]; } diff --git a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp index e7e64ce5b3d4..b58adb5d52fd 100644 --- a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp +++ b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp @@ -40,20 +40,27 @@ #include namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. class DeadFunctionEliminationPass : public mlir::ModulePass { public: void runOnModule() override { - std::string str = "main"; - auto module = getModule(); - for (auto &f : module) { - // eliminate dead functions that are not main - if (str.find(f.getName().getStringRef()) == std::string::npos) - f.erase(); + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); } } }; -} // namespace +} // end anonymous namespace /// Create a pass that eliminates inlined functions in toy. std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index 63eee4eefb8b..e285fac13a5f 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -126,6 +126,10 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, state.addOperands({lhs, rhs}); } +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + static void buildGenericCallOp(mlir::Builder *builder, mlir::OperationState &state, StringRef callee, ArrayRef arguments) { @@ -141,6 +145,29 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, state.addOperands({lhs, rhs}); } +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { + auto lhs = getOperand(0)->getType().cast(); + auto rhs = getOperand(1)->getType().cast(); + auto lhsRank = lhs.getShape().size(); + auto rhsRank = rhs.getShape().size(); + if (lhsRank != rhsRank) + return; + + SmallVector dims; + if (lhsRank == 1) { + // dot product, result shape is <1> + dims.push_back(1); + } else if (lhsRank == 2) { + dims.push_back(lhs.getShape()[0]); + dims.push_back(rhs.getShape()[1]); + } else { + return; + } + getResult()->setType(RankedTensorType::get(dims, lhs.getElementType())); +} + static mlir::LogicalResult verify(ReturnOp op) { // We know that the parent operation is a function, because of the 'HasParent' // trait attached to the operation definition. @@ -182,6 +209,15 @@ static void buildTransposeOp(mlir::Builder *builder, state.addOperands(value); } +void TransposeOp::inferShapes() { + SmallVector dims; + auto arrayTy = getOperand()->getType().cast(); + dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end()); + if (dims.size() == 2) + std::swap(dims[0], dims[1]); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index b8b091a62c5e..5acf8f9e394a 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -23,67 +23,47 @@ #include "mlir/Pass/Pass.h" #include "toy/Dialect.h" #include "toy/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include #define DEBUG_TYPE "shape-inference" -using llvm::MutableArrayRef; -using llvm::raw_ostream; -using llvm::SmallVector; -using llvm::SmallVectorImpl; -using llvm::StringRef; -using llvm::Twine; using namespace mlir; +using namespace toy; -namespace { - -// clang-format off -#include "toy/ShapeInferenceOpInterfaces.h.inc" +/// Include the auto-generated definitions for the shape inference interfaces. #include "toy/ShapeInferenceOpInterfaces.cpp.inc" +namespace { /// The ShapeInferencePass is a FunctionPass that performs intra-procedural /// shape inference. /// /// Algorithm: /// -/// 1) Build a worklist containing all the operations that are returning -/// a generic Toy array: these are the operations that need shape +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape /// inference. /// 2) Iterate on the worklist: /// a) find an operation to process: the next ready operation in the /// worklist has all of its arguments non-generic, /// b) if no operation is found, break out of the loop, /// c) remove the operation from the worklist, -/// d) infer the shape of its output from the arguments type. -/// 3) If the worklist is empty, the algorithm succeeded and we infer the -/// return type for the function from the return operation. +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. /// class ShapeInferencePass : public mlir::FunctionPass { public: - bool returnsGenericArray(Operation *op) { - if (op->getNumResults() == 1) { - if (!op->getResult(0)->getType().isa()) - return true; - } - return false; - } - void runOnFunction() override { auto f = getFunction(); // Populate the worklist with the operations that need shape inference: - // these are operations that return a generic array. + // these are operations that return a dynamic shape. llvm::SmallPtrSet opWorklist; f.walk([&](mlir::Operation *op) { - if (returnsGenericArray(op)) { + if (returnsDynamicShape(op)) opWorklist.insert(op); - } }); // Iterate on the operations in the worklist until all operations have been @@ -91,15 +71,14 @@ public: while (!opWorklist.empty()) { // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). - auto nextop = llvm::find_if(opWorklist, [this](Operation *op) { - return this->returnsGenericArray(op); - }); - + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); if (nextop == opWorklist.end()) - break; // failure: no operations can be inferred. + break; Operation *op = *nextop; opWorklist.erase(op); + + // Ask the operation to infer its output shapes. LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); auto shapeOp = dyn_cast(op); shapeOp.inferShapes(); @@ -107,11 +86,19 @@ public: // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; signalPassFailure(); - auto diag = f.emitError("Shape inference failed, ") - << opWorklist.size() << " operations couldn't be inferred\n"; } } + + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); + } }; } // end anonymous namespace