Code cleanups on Ch.4

This change performs general cleanups of the implementation of ch.4 and fixes some bugs. For example, the operations currently don't inherit from the shape inference interface.

PiperOrigin-RevId: 275089914
This commit is contained in:
River Riddle 2019-10-16 12:33:55 -07:00 committed by A. Unique TensorFlower
parent 3940b90d84
commit ab79c25d64
7 changed files with 121 additions and 92 deletions

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "toy/ShapeInferenceInterface.h"
namespace mlir { namespace mlir {
namespace toy { namespace toy {

View File

@ -92,7 +92,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
let verifier = [{ return ::verify(*this); }]; let verifier = [{ return ::verify(*this); }];
} }
def AddOp : Toy_Op<"add", [NoSideEffect]> { def AddOp : Toy_Op<"add",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation"; let summary = "element-wise addition operation";
let description = [{ let description = [{
The "add" operation performs element-wise addition between two tensors. The "add" operation performs element-wise addition between two tensors.
@ -108,12 +109,6 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
buildAddOp(b, result, lhs, rhs); buildAddOp(b, result, lhs, rhs);
}] }]
>]; >];
let extraClassDeclaration = [{
void inferShapes() {
getResult()->setType(getOperand(0)->getType());
return;
}
}];
} }
def GenericCallOp : Toy_Op<"generic_call"> { def GenericCallOp : Toy_Op<"generic_call"> {
@ -150,7 +145,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
]; ];
} }
def MulOp : Toy_Op<"mul", [NoSideEffect]> { def MulOp : Toy_Op<"mul",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation"; let summary = "element-wise multiplication operation";
let description = [{ let description = [{
The "mul" operation performs element-wise multiplication between two The "mul" operation performs element-wise multiplication between two
@ -166,30 +162,6 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {
buildMulOp(b, result, lhs, rhs); buildMulOp(b, result, lhs, rhs);
}] }]
>]; >];
let extraClassDeclaration = [{
void inferShapes() {
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank) {
return;
}
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else {
if (lhsRank != 2) {
return;
}
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
}
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
return;
}
}];
} }
def PrintOp : Toy_Op<"print"> { def PrintOp : Toy_Op<"print"> {
@ -255,7 +227,8 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let verifier = [{ return ::verify(*this); }]; let verifier = [{ return ::verify(*this); }];
} }
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { def TransposeOp : Toy_Op<"transpose",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation"; let summary = "transpose operation";
let arguments = (ins F64Tensor:$input); let arguments = (ins F64Tensor:$input);
@ -268,18 +241,6 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
buildTransposeOp(b, result, input); buildTransposeOp(b, result, input);
}] }]
>]; >];
let extraClassDeclaration = [{
void inferShapes() {
SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
dims.insert(dims.end(), arrayTy.getShape().begin(),
arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return;
}
}];
} }
#endif // TOY_OPS #endif // TOY_OPS

View File

@ -0,0 +1,37 @@
//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains the declarations of the shape inference interfaces defined
// in ShapeInferenceInterface.td.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace toy {
/// Include the auto-generated declarations.
#include "toy/ShapeInferenceOpInterfaces.h.inc"
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_

View File

@ -30,8 +30,8 @@ include "mlir/IR/OpBase.td"
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [ let methods = [
InterfaceMethod<"Infer output shape for the current operation.", InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes", (ins), [{}]> "void", "inferShapes">
]; ];
} }

View File

@ -40,20 +40,27 @@
#include <algorithm> #include <algorithm>
namespace { namespace {
/// This is a simple function DCE pass that deletes all non-main functions after
/// inlining.
/// TODO(riverriddle) This is only necessary because MLIR currently does not
/// have generic DCE support for functions.
class DeadFunctionEliminationPass class DeadFunctionEliminationPass
: public mlir::ModulePass<DeadFunctionEliminationPass> { : public mlir::ModulePass<DeadFunctionEliminationPass> {
public: public:
void runOnModule() override { void runOnModule() override {
std::string str = "main"; mlir::ModuleOp module = getModule();
auto module = getModule(); mlir::SymbolTable moduleSymTable(module);
for (auto &f : module) {
// eliminate dead functions that are not main // Eliminate non-main functions.
if (str.find(f.getName().getStringRef()) == std::string::npos) auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
f.erase(); for (mlir::FuncOp func :
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
if (func != mainFn)
func.erase();
} }
} }
}; };
} // namespace } // end anonymous namespace
/// Create a pass that eliminates inlined functions in toy. /// Create a pass that eliminates inlined functions in toy.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() { std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {

View File

@ -126,6 +126,10 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
state.addOperands({lhs, rhs}); state.addOperands({lhs, rhs});
} }
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
static void buildGenericCallOp(mlir::Builder *builder, static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee, mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) { ArrayRef<mlir::Value *> arguments) {
@ -141,6 +145,29 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
state.addOperands({lhs, rhs}); state.addOperands({lhs, rhs});
} }
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() {
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank)
return;
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else if (lhsRank == 2) {
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
} else {
return;
}
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
}
static mlir::LogicalResult verify(ReturnOp op) { static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent' // We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition. // trait attached to the operation definition.
@ -182,6 +209,15 @@ static void buildTransposeOp(mlir::Builder *builder,
state.addOperands(value); state.addOperands(value);
} }
void TransposeOp::inferShapes() {
SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -23,67 +23,47 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "toy/Dialect.h" #include "toy/Dialect.h"
#include "toy/Passes.h" #include "toy/Passes.h"
#include "llvm/ADT/STLExtras.h" #include "toy/ShapeInferenceInterface.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <algorithm>
#define DEBUG_TYPE "shape-inference" #define DEBUG_TYPE "shape-inference"
using llvm::MutableArrayRef;
using llvm::raw_ostream;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;
using namespace mlir; using namespace mlir;
using namespace toy;
namespace { /// Include the auto-generated definitions for the shape inference interfaces.
// clang-format off
#include "toy/ShapeInferenceOpInterfaces.h.inc"
#include "toy/ShapeInferenceOpInterfaces.cpp.inc" #include "toy/ShapeInferenceOpInterfaces.cpp.inc"
namespace {
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural /// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference. /// shape inference.
/// ///
/// Algorithm: /// Algorithm:
/// ///
/// 1) Build a worklist containing all the operations that are returning /// 1) Build a worklist containing all the operations that return a
/// a generic Toy array: these are the operations that need shape /// dynamically shaped tensor: these are the operations that need shape
/// inference. /// inference.
/// 2) Iterate on the worklist: /// 2) Iterate on the worklist:
/// a) find an operation to process: the next ready operation in the /// a) find an operation to process: the next ready operation in the
/// worklist has all of its arguments non-generic, /// worklist has all of its arguments non-generic,
/// b) if no operation is found, break out of the loop, /// b) if no operation is found, break out of the loop,
/// c) remove the operation from the worklist, /// c) remove the operation from the worklist,
/// d) infer the shape of its output from the arguments type. /// d) infer the shape of its output from the argument types.
/// 3) If the worklist is empty, the algorithm succeeded and we infer the /// 3) If the worklist is empty, the algorithm succeeded.
/// return type for the function from the return operation.
/// ///
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> { class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public: public:
bool returnsGenericArray(Operation *op) {
if (op->getNumResults() == 1) {
if (!op->getResult(0)->getType().isa<ShapedType>())
return true;
}
return false;
}
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
// Populate the worklist with the operations that need shape inference: // Populate the worklist with the operations that need shape inference:
// these are operations that return a generic array. // these are operations that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsGenericArray(op)) { if (returnsDynamicShape(op))
opWorklist.insert(op); opWorklist.insert(op);
}
}); });
// Iterate on the operations in the worklist until all operations have been // Iterate on the operations in the worklist until all operations have been
@ -91,15 +71,14 @@ public:
while (!opWorklist.empty()) { while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation // Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic). // with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [this](Operation *op) { auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
return this->returnsGenericArray(op);
});
if (nextop == opWorklist.end()) if (nextop == opWorklist.end())
break; // failure: no operations can be inferred. break;
Operation *op = *nextop; Operation *op = *nextop;
opWorklist.erase(op); opWorklist.erase(op);
// Ask the operation to infer its output shapes.
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
auto shapeOp = dyn_cast<ShapeInference>(op); auto shapeOp = dyn_cast<ShapeInference>(op);
shapeOp.inferShapes(); shapeOp.inferShapes();
@ -107,11 +86,19 @@ public:
// If the operation worklist isn't empty, this indicates a failure. // If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) { if (!opWorklist.empty()) {
signalPassFailure(); f.emitError("Shape inference failed, ")
auto diag = f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n"; << opWorklist.size() << " operations couldn't be inferred\n";
signalPassFailure();
} }
} }
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
});
}
}; };
} // end anonymous namespace } // end anonymous namespace