forked from OSchip/llvm-project
Code cleanups on Ch.4
This change performs general cleanups of the implementation of ch.4 and fixes some bugs. For example, the operations currently don't inherit from the shape inference interface. PiperOrigin-RevId: 275089914
This commit is contained in:
parent
3940b90d84
commit
ab79c25d64
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace toy {
|
||||
|
|
|
@ -92,7 +92,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
||||
def AddOp : Toy_Op<"add",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "element-wise addition operation";
|
||||
let description = [{
|
||||
The "add" operation performs element-wise addition between two tensors.
|
||||
|
@ -108,12 +109,6 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
|||
buildAddOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
void inferShapes() {
|
||||
getResult()->setType(getOperand(0)->getType());
|
||||
return;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call"> {
|
||||
|
@ -150,7 +145,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
|
|||
];
|
||||
}
|
||||
|
||||
def MulOp : Toy_Op<"mul", [NoSideEffect]> {
|
||||
def MulOp : Toy_Op<"mul",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "element-wise multiplication operation";
|
||||
let description = [{
|
||||
The "mul" operation performs element-wise multiplication between two
|
||||
|
@ -166,30 +162,6 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {
|
|||
buildMulOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
void inferShapes() {
|
||||
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
auto lhsRank = lhs.getShape().size();
|
||||
auto rhsRank = rhs.getShape().size();
|
||||
if (lhsRank != rhsRank) {
|
||||
return;
|
||||
}
|
||||
SmallVector<int64_t, 2> dims;
|
||||
if (lhsRank == 1) {
|
||||
// dot product, result shape is <1>
|
||||
dims.push_back(1);
|
||||
} else {
|
||||
if (lhsRank != 2) {
|
||||
return;
|
||||
}
|
||||
dims.push_back(lhs.getShape()[0]);
|
||||
dims.push_back(rhs.getShape()[1]);
|
||||
}
|
||||
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
|
||||
return;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def PrintOp : Toy_Op<"print"> {
|
||||
|
@ -255,7 +227,8 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
|||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
||||
def TransposeOp : Toy_Op<"transpose",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "transpose operation";
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
|
@ -268,18 +241,6 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
|||
buildTransposeOp(b, result, input);
|
||||
}]
|
||||
>];
|
||||
let extraClassDeclaration = [{
|
||||
void inferShapes() {
|
||||
SmallVector<int64_t, 2> dims;
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
dims.insert(dims.end(), arrayTy.getShape().begin(),
|
||||
arrayTy.getShape().end());
|
||||
if (dims.size() == 2)
|
||||
std::swap(dims[0], dims[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
return;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TOY_OPS
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains the declarations of the shape inference interfaces defined
|
||||
// in ShapeInferenceInterface.td.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
||||
#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace toy {
|
||||
|
||||
/// Include the auto-generated declarations.
|
||||
#include "toy/ShapeInferenceOpInterfaces.h.inc"
|
||||
|
||||
} // end namespace toy
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
|
|
@ -30,8 +30,8 @@ include "mlir/IR/OpBase.td"
|
|||
|
||||
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer output shape for the current operation.",
|
||||
"void", "inferShapes", (ins), [{}]>
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"void", "inferShapes">
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -40,20 +40,27 @@
|
|||
#include <algorithm>
|
||||
|
||||
namespace {
|
||||
/// This is a simple function DCE pass that deletes all non-main functions after
|
||||
/// inlining.
|
||||
/// TODO(riverriddle) This is only necessary because MLIR currently does not
|
||||
/// have generic DCE support for functions.
|
||||
class DeadFunctionEliminationPass
|
||||
: public mlir::ModulePass<DeadFunctionEliminationPass> {
|
||||
public:
|
||||
void runOnModule() override {
|
||||
std::string str = "main";
|
||||
auto module = getModule();
|
||||
for (auto &f : module) {
|
||||
// eliminate dead functions that are not main
|
||||
if (str.find(f.getName().getStringRef()) == std::string::npos)
|
||||
f.erase();
|
||||
mlir::ModuleOp module = getModule();
|
||||
mlir::SymbolTable moduleSymTable(module);
|
||||
|
||||
// Eliminate non-main functions.
|
||||
auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
|
||||
for (mlir::FuncOp func :
|
||||
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
|
||||
if (func != mainFn)
|
||||
func.erase();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Create a pass that eliminates inlined functions in toy.
|
||||
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
|
||||
|
|
|
@ -126,6 +126,10 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
|||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
|
||||
static void buildGenericCallOp(mlir::Builder *builder,
|
||||
mlir::OperationState &state, StringRef callee,
|
||||
ArrayRef<mlir::Value *> arguments) {
|
||||
|
@ -141,6 +145,29 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
|||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() {
|
||||
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
|
||||
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
|
||||
auto lhsRank = lhs.getShape().size();
|
||||
auto rhsRank = rhs.getShape().size();
|
||||
if (lhsRank != rhsRank)
|
||||
return;
|
||||
|
||||
SmallVector<int64_t, 2> dims;
|
||||
if (lhsRank == 1) {
|
||||
// dot product, result shape is <1>
|
||||
dims.push_back(1);
|
||||
} else if (lhsRank == 2) {
|
||||
dims.push_back(lhs.getShape()[0]);
|
||||
dims.push_back(rhs.getShape()[1]);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(ReturnOp op) {
|
||||
// We know that the parent operation is a function, because of the 'HasParent'
|
||||
// trait attached to the operation definition.
|
||||
|
@ -182,6 +209,15 @@ static void buildTransposeOp(mlir::Builder *builder,
|
|||
state.addOperands(value);
|
||||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
SmallVector<int64_t, 2> dims;
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
|
||||
if (dims.size() == 2)
|
||||
std::swap(dims[0], dims[1]);
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -23,67 +23,47 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "toy/Dialect.h"
|
||||
#include "toy/Passes.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "toy/ShapeInferenceInterface.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <algorithm>
|
||||
|
||||
#define DEBUG_TYPE "shape-inference"
|
||||
|
||||
using llvm::MutableArrayRef;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::SmallVector;
|
||||
using llvm::SmallVectorImpl;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
using namespace mlir;
|
||||
using namespace toy;
|
||||
|
||||
namespace {
|
||||
|
||||
// clang-format off
|
||||
#include "toy/ShapeInferenceOpInterfaces.h.inc"
|
||||
/// Include the auto-generated definitions for the shape inference interfaces.
|
||||
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
|
||||
|
||||
namespace {
|
||||
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
|
||||
/// shape inference.
|
||||
///
|
||||
/// Algorithm:
|
||||
///
|
||||
/// 1) Build a worklist containing all the operations that are returning
|
||||
/// a generic Toy array: these are the operations that need shape
|
||||
/// 1) Build a worklist containing all the operations that return a
|
||||
/// dynamically shaped tensor: these are the operations that need shape
|
||||
/// inference.
|
||||
/// 2) Iterate on the worklist:
|
||||
/// a) find an operation to process: the next ready operation in the
|
||||
/// worklist has all of its arguments non-generic,
|
||||
/// b) if no operation is found, break out of the loop,
|
||||
/// c) remove the operation from the worklist,
|
||||
/// d) infer the shape of its output from the arguments type.
|
||||
/// 3) If the worklist is empty, the algorithm succeeded and we infer the
|
||||
/// return type for the function from the return operation.
|
||||
/// d) infer the shape of its output from the argument types.
|
||||
/// 3) If the worklist is empty, the algorithm succeeded.
|
||||
///
|
||||
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||
public:
|
||||
bool returnsGenericArray(Operation *op) {
|
||||
if (op->getNumResults() == 1) {
|
||||
if (!op->getResult(0)->getType().isa<ShapedType>())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are operations that return a generic array.
|
||||
// these are operations that return a dynamic shape.
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsGenericArray(op)) {
|
||||
if (returnsDynamicShape(op))
|
||||
opWorklist.insert(op);
|
||||
}
|
||||
});
|
||||
|
||||
// Iterate on the operations in the worklist until all operations have been
|
||||
|
@ -91,15 +71,14 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
|
||||
return this->returnsGenericArray(op);
|
||||
});
|
||||
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
if (nextop == opWorklist.end())
|
||||
break; // failure: no operations can be inferred.
|
||||
break;
|
||||
|
||||
Operation *op = *nextop;
|
||||
opWorklist.erase(op);
|
||||
|
||||
// Ask the operation to infer its output shapes.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
|
||||
auto shapeOp = dyn_cast<ShapeInference>(op);
|
||||
shapeOp.inferShapes();
|
||||
|
@ -107,11 +86,19 @@ public:
|
|||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!opWorklist.empty()) {
|
||||
f.emitError("Shape inference failed, ")
|
||||
<< opWorklist.size() << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
auto diag = f.emitError("Shape inference failed, ")
|
||||
<< opWorklist.size() << " operations couldn't be inferred\n";
|
||||
}
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
|
Loading…
Reference in New Issue