Update Chapter 4 of the Toy tutorial

This Chapter now introduces and makes use of the Interface concept
in MLIR to demonstrate ShapeInference.

Closes tensorflow/mlir#191

PiperOrigin-RevId: 275085151
Sana Damani 2019-10-16 12:08:55 -07:00 committed by A. Unique TensorFlower
parent e88dbc8c95
commit 3940b90d84
26 changed files with 1128 additions and 1644 deletions

View File

@@ -1,16 +1,29 @@
add_subdirectory(include)
set(LLVM_LINK_COMPONENTS
Support
)
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
add_public_tablegen_target(ToyCh4CombineIncGen)
add_toy_chapter(toyc-ch4
toyc.cpp
parser/AST.cpp
mlir/MLIRGen.cpp
mlir/ToyDialect.cpp
mlir/Dialect.cpp
mlir/DeadFunctionEliminationPass.cpp
mlir/ShapeInferencePass.cpp
mlir/ToyCombine.cpp
)
add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch4
PRIVATE
MLIRAnalysis

View File

@@ -0,0 +1 @@
add_subdirectory(toy)

View File

@@ -33,10 +33,9 @@
namespace toy {
/// A variable
/// A variable type with shape information.
struct VarType {
enum { TY_FLOAT, TY_INT } elt_ty;
std::vector<int> shape;
std::vector<int64_t> shape;
};
/// Base class for all expression nodes.
@@ -50,9 +49,7 @@ public:
Expr_Var,
Expr_BinOp,
Expr_Call,
Expr_Print, // builtin
Expr_If,
Expr_For,
Expr_Print,
};
ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};
///
/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};
///
/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
@@ -136,7 +133,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};
///
/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;

View File

@@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
add_public_tablegen_target(ToyCh4OpsIncGen)
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen)

View File

@@ -16,7 +16,7 @@
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
// See g3doc/Tutorials/Toy/Ch-3.md for more information.
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//
@@ -25,325 +25,30 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class Builder;
}
namespace toy {
/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and register custom operations and types (in its constructor).
/// It can also overriding general behavior of dialects exposed as virtual
/// method, for example regarding verification and parsing/printing.
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override some general behavior exposed via virtual
/// methods.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);
/// Parse a type registered to this dialect. Overriding this method is
/// required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
mlir::Type parseType(llvm::StringRef tyData,
mlir::Location loc) const override;
/// Print a type registered to this dialect. Overriding this method is
/// only required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
static llvm::StringRef getDialectNamespace() { return "toy"; }
};
////////////////////////////////////////////////////////////////////////////////
/////////////////////// Custom Types for the Dialect ///////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace detail {
struct ToyArrayTypeStorage;
}
/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
enum ToyTypeKind {
// The enum starts at the range reserved for this dialect.
TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
TOY_ARRAY,
};
/// Type for Toy arrays.
/// In MLIR, Types are references to immutable and uniqued objects owned by the
/// MLIRContext. As such, `ToyArrayType` only wraps a pointer to a uniqued
/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
/// provides the public facade API to interact with the type.
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
detail::ToyArrayTypeStorage> {
public:
using Base::Base;
/// Returns the dimensions for this array, or an empty range for a generic
/// array.
llvm::ArrayRef<int64_t> getShape();
/// Predicate to test if this array is generic (shape hasn't been inferred
/// yet).
bool isGeneric() { return getShape().empty(); }
/// Return the rank of this array (0 if it is generic).
int getRank() { return getShape().size(); }
/// Return the type of individual elements in the array.
mlir::Type getElementType();
/// Get the unique instance of this Type from the context.
/// A ToyArrayType is only defined by the shape of the array.
static ToyArrayType get(mlir::MLIRContext *context,
llvm::ArrayRef<int64_t> shape = {});
/// Support method to enable LLVM-style RTTI type casting.
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
};
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Constant operation turns a literal into an SSA value. The data is attached
/// to the operation as an attribute. For example:
///
/// %0 = "toy.constant"()
/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
/// : () -> !toy.array<2, 3>
///
/// An operation inherits from `class Op` and specifies optional traits. Here we
/// indicate that `toy.constant` does not have any operands and returns a single
/// result. The traits provide some utility methods for the operation; for
/// instance, we will be able to use `getResult()`, but `getOperand()` won't be
/// available.
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
/// This is the name used by MLIR to match an operation to this class during
/// parsing.
static llvm::StringRef getOperationName() { return "toy.constant"; }
/// The operation can have extra verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<ConstantOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.constant` operation does not have arguments but attaches a
/// constant array as an attribute and returns it as an SSA value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::ArrayRef<int64_t> shape,
mlir::DenseElementsAttr value);
/// Similar to the one above, but takes a single float and returns a
/// !toy.array<1>.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::FloatAttr value);
/// Inherit constructor.
using Op::Op;
};
/// Generic calls represent calls to a user defined function that needs to
/// be specialized for the shape of its arguments. The callee name is attached
/// as a literal string as an attribute. The arguments list must match the
/// arguments expected by the callee. For example:
///
/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
///
/// This is only valid if a function named "my_func" exists and takes two
/// arguments.
class GenericCallOp
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult> {
public:
/// MLIR will use this to register the operation with the parser/printer.
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to the builder to allow:
/// mlir::Builder::create<GenericCallOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.generic_call` operation accepts a callee name and a list of
/// arguments for the call.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::StringRef callee,
llvm::ArrayRef<mlir::Value *> arguments);
/// Return the name of the callee.
llvm::StringRef getCalleeName();
/// Inherit constructor.
using Op::Op;
};
/// Return operations terminate blocks (and functions as well). They take a
/// single argument and the type must match the function return type.
class ReturnOp
: public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
public:
static llvm::StringRef getOperationName() { return "toy.return"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<ReturnOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.return` operation accepts an optional single array as an argument
/// and does not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value = nullptr);
/// Return true if there is a returned value.
bool hasOperand() { return 0 != getNumOperands(); }
/// Helper to return the optional operand. Caller must check if the operand
/// is present before calling this.
mlir::Value *getOperand() { return getOperation()->getOperand(0); }
/// Inherit constructor.
using Op::Op;
};
/// The print builtin takes a single array argument and does not return any.
class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::ZeroResult> {
public:
static llvm::StringRef getOperationName() { return "toy.print"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.print` operation accepts a single array as argument and does
/// not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
/// Inherit constructor.
using Op::Op;
};
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.transpose"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<TransposeOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.transpose` operation accepts a single array as argument and
/// returns the transposed array as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
// Register our patterns for rewrite by the Canonicalization framework.
static void
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context);
/// Inherit constructor.
using Op::Op;
};
/// The reshape operation transforms its input array into a new array with the
/// same number of elements but a different shape. For example:
///
/// %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
///
class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.reshape"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<ReshapeOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.reshape` operation accepts a single array as argument and
/// returns the array with the specified reshapedType as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, ToyArrayType reshapedType);
// Register our patterns for rewrite by the Canonicalization framework.
static void
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context);
/// Inherit constructor.
using Op::Op;
};
/// Binary operation implementing a multiplication. For two-dimensional arrays
/// a matrix multiplication is performed, while for one-dimensional arrays a
/// dot product is performed.
class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.mul"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<MulOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.mul` operation accepts two operands as arguments and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Element wise addition of two arrays. The shape must match.
class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.add"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<AddOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.add` operation accepts two operands as arguments and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Include the auto-generated header file containing the declarations of the
/// toy operations.
#define GET_OP_CLASSES
#include "toy/Ops.h.inc"
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_

View File

@@ -31,7 +31,7 @@ namespace toy {
/// Structure definition a location in a file.
struct Location {
std::shared_ptr<std::string> file; ///< filename
std::shared_ptr<std::string> file; ///< filename.
int line; ///< line number.
int col; ///< column number.
};

View File

@@ -0,0 +1,285 @@
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the operations of the Toy dialect.
//
//===----------------------------------------------------------------------===//
#ifdef TOY_OPS
#else
#define TOY_OPS
#ifdef SHAPE_INFERENCE_INTERFACE
#else
include "toy/ShapeInferenceInterface.td"
#endif // SHAPE_INFERENCE_INTERFACE
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "toy";
}
// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Toy_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
// We define a toy operation by inheriting from our base 'Toy_Op' class above.
// Here we provide the mnemonic and a list of traits for the operation. The
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// Provide a summary and description for this operation. This can be used to
// auto-generate documentation of the operations within our dialect.
let summary = "constant";
let description = [{
Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:
```mlir
%0 = "toy.constant"()
{ value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
: () -> tensor<2x3xf64>
```
}];
// The constant operation takes an attribute as the only input.
let arguments = (ins F64ElementsAttr:$value);
// The constant operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Add custom build methods for the constant operation. These methods populate
// the `state` that MLIR uses to create operations, i.e. these are used when
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
];
// Invoke a static verify method to verify this constant operation.
let verifier = [{ return ::verify(*this); }];
}
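The declarative `builders` above are turned into overloads of `ConstantOp::build` by TableGen, and `mlir::OpBuilder::create<ConstantOp>(...)` selects among them. A minimal sketch of how the `double` overload is reached from C++ follows; the helper name `emitScalarConstant` is illustrative and not part of the patch.

```c++
#include "mlir/IR/Builders.h"
#include "toy/Dialect.h"

/// Illustrative helper (not in the patch): create a scalar "toy.constant".
/// This resolves to the `double` builder declared above, which forwards to
/// the static buildConstantOp() helper defined in Dialect.cpp.
static mlir::Value *emitScalarConstant(mlir::OpBuilder &builder,
                                       mlir::Location location, double value) {
  return builder.create<mlir::toy::ConstantOp>(location, value);
}
```

The `DenseElementsAttr` overload is the one MLIRGen uses when lowering literal arrays, as shown later in this diff.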
def AddOp : Toy_Op<"add", [NoSideEffect]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
The shapes of the tensor operands are expected to match.
}];
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
let extraClassDeclaration = [{
void inferShapes() {
getResult()->setType(getOperand(0)->getType());
return;
}
}];
}
def GenericCallOp : Toy_Op<"generic_call"> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to
be specialized for the shape of its arguments. The callee name is attached
as a symbol reference via an attribute. The arguments list must match the
arguments expected by the callee. For example:
```mlir
%4 = "toy.generic_call"(%1, %3) {callee = @my_func}
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
```
This is only valid if a function named "my_func" exists and takes two
arguments.
}];
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and call arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
];
}
def MulOp : Toy_Op<"mul", [NoSideEffect]> {
let summary = "element-wise multiplication operation";
let description = [{
The "mul" operation performs element-wise multiplication between two
tensors. The shapes of the tensor operands are expected to match.
}];
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
let extraClassDeclaration = [{
void inferShapes() {
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank) {
return;
}
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else {
if (lhsRank != 2) {
return;
}
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
}
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
return;
}
}];
}
def PrintOp : Toy_Op<"print"> {
let summary = "print operation";
let description = [{
The "print" builtin operation prints a given input tensor, and produces
no results.
}];
// The print operation takes an input tensor to print.
let arguments = (ins F64Tensor:$input);
}
def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
let summary = "tensor reshape operation";
let description = [{
The reshape operation transforms its input tensor into a new tensor with
the same number of elements but a different shape. For example:
```mlir
%0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
```
}];
let arguments = (ins F64Tensor:$input);
let hasCanonicalizer = 1;
// We expect that the reshape operation returns a statically shaped tensor.
let results = (outs StaticShapeTensorOf<[F64]>);
}
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let summary = "return operation";
let description = [{
The "return" operation represents a return operation within a function.
The operation takes an optional tensor operand and produces no results.
The operand type must match the signature of the function that contains
the operation. For example:
```mlir
func @foo() -> tensor<2xf64> {
...
toy.return %0 : tensor<2xf64>
}
```
}];
// The return operation takes an optional input operand to return. This
// value must match the return type of the enclosing function.
let arguments = (ins Variadic<F64Tensor>:$input);
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
let extraClassDeclaration = [{
bool hasOperand() { return getNumOperands() != 0; }
}];
// Invoke a static verify method to verify this return operation.
let verifier = [{ return ::verify(*this); }];
}
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor);
let hasCanonicalizer = 1;
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
let extraClassDeclaration = [{
void inferShapes() {
SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
dims.insert(dims.end(), arrayTy.getShape().begin(),
arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return;
}
}];
}
#endif // TOY_OPS

View File

@@ -26,10 +26,11 @@
namespace mlir {
class Pass;
} // namespace mlir
namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass();
} // namespace toy
std::unique_ptr<Pass> createShapeInferencePass();
std::unique_ptr<Pass> createDeadFunctionEliminationPass();
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_PASSES_H

View File

@@ -0,0 +1,38 @@
//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the Shape Inference Op Interface.
//
//===----------------------------------------------------------------------===//
#ifdef SHAPE_INFERENCE_INTERFACE
#else
#define SHAPE_INFERENCE_INTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [
InterfaceMethod<"Infer output shape for the current operation.",
"void", "inferShapes", (ins), [{}]>
];
}
#endif // SHAPE_INFERENCE_INTERFACE
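TableGen turns this `OpInterface` into a C++ `ShapeInference` class (declared in the generated ShapeInferenceOpInterfaces.h.inc named in the CMake rules above). A pass can then test any operation for the interface and call `inferShapes()` without knowing the concrete op. A minimal sketch, assuming the generated header is pulled in through toy/Dialect.h and lives in the same namespace as the Toy ops; the helper name is illustrative, and the ShapeInferencePass itself is only partially shown in this diff.

```c++
#include "mlir/IR/Operation.h"
#include "toy/Dialect.h"  // assumed to pull in the generated interface decls

/// Illustrative helper: infer the result shape of `op` if it implements the
/// ShapeInference interface.
static void inferShapeIfSupported(mlir::Operation *op) {
  // dyn_cast yields a null interface handle when `op` does not implement it.
  if (auto shapeOp = llvm::dyn_cast<mlir::toy::ShapeInference>(op))
    shapeOp.inferShapes();
}
```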

View File

@@ -0,0 +1,61 @@
//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing dead function
// elimination. This is required as a post-processing step after function
// inlining.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "toy/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
namespace {
class DeadFunctionEliminationPass
: public mlir::ModulePass<DeadFunctionEliminationPass> {
public:
void runOnModule() override {
std::string str = "main";
auto module = getModule();
for (auto &f : module) {
// Eliminate dead functions that are not "main".
if (str.find(f.getName().getStringRef()) == std::string::npos)
f.erase();
}
}
};
} // namespace
/// Create a pass that eliminates inlined functions in toy.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
return std::make_unique<DeadFunctionEliminationPass>();
}
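Neither new pass is useful in isolation: toyc.cpp (not part of this excerpt) has to add them to a pass pipeline. The sketch below shows one plausible wiring, with dead function elimination at module scope after inlining and shape inference nested on each remaining function; the exact ordering and nesting in the real driver are assumptions here, and `buildToyPipeline` is an illustrative name.

```c++
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "toy/Passes.h"

/// Illustrative pipeline setup (the real driver may order or nest these
/// differently): clean up inlined callees, then infer shapes per function.
static void buildToyPipeline(mlir::PassManager &pm) {
  // Module-level pass: erase functions other than "main" once inlining is done.
  pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
  // Function-level pass: run shape inference on each remaining function.
  pm.nest<mlir::FuncOp>().addPass(mlir::toy::createShapeInferencePass());
}
```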

View File

@@ -0,0 +1,190 @@
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::toy;
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//
/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
/// All operations within toy can be inlined.
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator (toy.return) by replacing it with a new
/// operation as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value *> valuesToRepl) const final {
// Only "toy.return" needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
};
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
>();
addInterfaces<ToyInlinerInterface>();
}
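The `GET_OP_LIST` include expands to the comma-separated list of op classes generated from Ops.td, so `addOperations<...>()` registers every Toy operation in one shot, and `addInterfaces<ToyInlinerInterface>()` makes the dialect inlinable. The dialect itself still has to be registered with MLIR before any Toy IR is built; a sketch of that registration as it worked at the time of this commit (the driver is not shown in this excerpt, and the wrapper function name is illustrative):

```c++
#include "toy/Dialect.h"

/// Illustrative registration call, normally issued once from the driver
/// (e.g. toyc.cpp): make ToyDialect known to MLIRContexts created afterwards.
static void registerToyDialect() {
  mlir::registerDialect<mlir::toy::ToyDialect>();
}
```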
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
/// Build a constant operation.
/// The builder is passed as an argument, as is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = builder->getTensorType({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Verifier for constant operation.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
if (!resultType)
return success();
auto attrType = op.value().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op.emitOpError(
"return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
for (int dim = 0; dim < attrType.getRank(); ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return op.emitOpError(
"return type shape mismatches its attribute at dimension ")
<< dim << ": " << attrType.getShape()[dim]
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
auto function = cast<FuncOp>(op.getParentOp());
/// ReturnOps can only have a single optional operand.
if (op.getNumOperands() > 1)
return op.emitOpError() << "expects at most 1 return operand";
// The operand number and types must match the function signature.
const auto &results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError()
<< "does not return the same number of values ("
<< op.getNumOperands() << ") as the enclosing function ("
<< results.size() << ")";
// If the operation does not have an input, we are done.
if (!op.hasOperand())
return mlir::success();
auto inputType = *op.operand_type_begin();
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
return mlir::success();
return op.emitError() << "type of return operand ("
<< *op.operand_type_begin()
<< ") doesn't match function result type ("
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands(value);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"

View File

@@ -25,30 +25,30 @@
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
using namespace mlir::toy;
using namespace toy;
using llvm::ArrayRef;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
using llvm::makeArrayRef;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using std::make_unique;
namespace {
@@ -57,56 +57,43 @@ namespace {
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
///
/// At this point we take advantage of the "raw" MLIR APIs to create operations
/// that haven't been registered in any way with MLIR. These operations are
/// unknown to MLIR, custom passes could operate by string-matching the name of
/// these operations, but no other type checking or semantic is associated with
/// them natively by MLIR.
class MLIRGenImpl {
public:
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
/// Module operation.
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->push_back(func);
theModule.push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
// this won't do much, but it should at least check some structural
// properties.
if (failed(mlir::verify(*theModule))) {
emitError(mlir::UnknownLoc::get(&context), "module verification error");
// Verify the module after we have finished constructing it; this will check
// the structural properties of the IR and invoke any specific verifiers we
// have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("Module verification error");
return nullptr;
}
return std::move(theModule);
return theModule;
}
private:
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
/// the ownership of many internal structure of the IR and provide a level
/// of "uniquing" across multiple modules (types for instance).
mlir::MLIRContext &context;
/// A "module" matches a Toy source file: containing a list of functions.
mlir::ModuleOp theModule;
/// A "module" matches a source file: it contains a list of functions.
mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
/// convenience for emitting individual operations.
/// The builder is stateful, in particular it keeps an "insertion point":
/// this is where the next operations will be introduced.
std::unique_ptr<mlir::OpBuilder> builder;
/// The builder is a helper class to create IR inside a function. The builder
/// is stateful; in particular it keeps an "insertion point": this is where
/// the next operations will be introduced.
mlir::OpBuilder builder;
/// The symbol table maps a variable name to a value in the current scope.
/// Entering a function creates a new scope, and the function arguments are
@@ -116,37 +103,35 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
loc.line, loc.col, &context);
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}
/// Declare a variable in the current scope, return true if the variable
/// Declare a variable in the current scope, return success if the variable
/// wasn't declared yet.
bool declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var)) {
return false;
}
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var))
return mlir::failure();
symbolTable.insert(var, value);
return true;
return mlir::success();
}
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
// Argument types are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto func_type = builder.getFunctionType(arg_types, llvm::None);
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
function.setAttr("toy.generic", builder.getUnitAttr());
return function;
}
@@ -165,29 +150,39 @@ private:
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
llvm::zip(protoArgs, entryBlock.getArguments())) {
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
if (failed(declare(std::get<0>(name_value)->getName(),
std::get<1>(name_value))))
return nullptr;
}
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
// Set the insertion point in the builder to the beginning of the function
// body, it will be used throughout the codegen to create operations in this
// function.
builder.setInsertionPointToStart(&entryBlock);
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody())) {
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emited.
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
ReturnOp returnOp;
if (!entryBlock.empty())
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
if (!returnOp) {
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
} else if (returnOp.hasOperand()) {
// Otherwise, if this return operation has an operand then add a result to
// the function.
function.setType(builder.getFunctionType(function.getType().getInputs(),
getType(VarType{})));
}
return function;
@@ -206,11 +201,11 @@ private:
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
mlir::Value *L = mlirGen(*binop.getLHS());
if (!L)
mlir::Value *lhs = mlirGen(*binop.getLHS());
if (!lhs)
return nullptr;
mlir::Value *R = mlirGen(*binop.getRHS());
if (!R)
mlir::Value *rhs = mlirGen(*binop.getRHS());
if (!rhs)
return nullptr;
auto location = loc(binop.loc());
@@ -218,123 +213,112 @@ private:
// support '+' and '*'.
switch (binop.getOp()) {
case '+':
return builder->create<AddOp>(location, L, R).getResult();
break;
return builder.create<AddOp>(location, lhs, rhs);
case '*':
return builder->create<MulOp>(location, L, R).getResult();
default:
emitError(location, "error: invalid binary operator '")
<< binop.getOp() << "'";
return nullptr;
return builder.create<MulOp>(location, lhs, rhs);
}
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
return nullptr;
}
// This is a reference to a variable in an expression. The variable is
// expected to have been declared and so should have a value in the symbol
// table, otherwise emit an error and return nullptr.
/// This is a reference to a variable in an expression. The variable is
/// expected to have been declared and so should have a value in the symbol
/// table, otherwise emit an error and return nullptr.
mlir::Value *mlirGen(VariableExprAST &expr) {
if (symbolTable.count(expr.getName()))
return symbolTable.lookup(expr.getName());
emitError(loc(expr.loc()), "error: unknown variable '")
if (auto *variable = symbolTable.lookup(expr.getName()))
return variable;
emitError(loc(expr.loc()), "Error: unknown variable '")
<< expr.getName() << "'";
return nullptr;
}
// Emit a return operation, return true on success.
bool mlirGen(ReturnExprAST &ret) {
/// Emit a return operation. This will return failure if any generation fails.
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
auto location = loc(ret.loc());
// `return` takes an optional expression, we need to account for it here.
if (!ret.getExpr().hasValue()) {
builder->create<ReturnOp>(location);
return true;
// 'return' takes an optional expression, handle that case here.
mlir::Value *expr = nullptr;
if (ret.getExpr().hasValue()) {
if (!(expr = mlirGen(*ret.getExpr().getValue())))
return mlir::failure();
}
auto *expr = mlirGen(*ret.getExpr().getValue());
if (!expr)
return false;
builder->create<ReturnOp>(location, expr);
return true;
// Otherwise, this return operation has zero operands.
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
: ArrayRef<mlir::Value *>());
return mlir::success();
}
// Emit a literal/constant array. It will be emitted as a flattened array of
// data in an Attribute attached to a `toy.constant` operation.
// See documentation on [Attributes](LangRef.md#attributes) for more details.
// Here is an excerpt:
//
// Attributes are the mechanism for specifying constant data in MLIR in
// places where a variable is never allowed [...]. They consist of a name
// and a [concrete attribute value](#attribute-values). It is possible to
// attach attributes to operations, functions, and function arguments. The
// set of expected attributes, their structure, and their interpretation
// are all contextually dependent on what they are attached to.
//
// Example, the source level statement:
// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
// will be converted to:
// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
//
/// Emit a literal/constant array. It will be emitted as a flattened array of
/// data in an Attribute attached to a `toy.constant` operation.
/// See documentation on [Attributes](LangRef.md#attributes) for more details.
/// Here is an excerpt:
///
/// Attributes are the mechanism for specifying constant data in MLIR in
/// places where a variable is never allowed [...]. They consist of a name
/// and a concrete attribute value. The set of expected attributes, their
/// structure, and their interpretation are all contextually dependent on
/// what they are attached to.
///
/// Example, the source level statement:
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
/// will be converted to:
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
///
mlir::Value *mlirGen(LiteralExprAST &lit) {
auto location = loc(lit.loc());
// The attribute is a vector with an attribute per element (number) in the
// array, see `collectData()` below for more details.
std::vector<mlir::Attribute> data;
auto type = getType(lit.getDims());
// The attribute is a vector with a floating point value per element
// (number) in the array, see `collectData()` below for more details.
std::vector<double> data;
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
std::multiplies<int>()));
collectData(lit, data);
// FIXME: using a tensor type is a HACK here.
// Can we do differently without registering a dialect? Using a string blob?
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto dataType = builder->getTensorType(lit.getDims(), elementType);
// The type of this attribute is tensor of 64-bit floating-point with the
// shape of the literal.
mlir::Type elementType = builder.getF64Type();
auto dataType = builder.getTensorType(lit.getDims(), elementType);
// This is the actual attribute that actually hold the list of values for
// this array literal.
auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
.cast<mlir::DenseElementsAttr>();
// This is the actual attribute that holds the list of values for this
// tensor literal.
auto dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
// Build the MLIR op `toy.constant`, only boilerplate below.
return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
.getResult();
// Build the MLIR op `toy.constant`.
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
// Recursive helper function to accumulate the data that compose an array
// literal. It flattens the nested structure in the supplied vector. For
// example with this array:
// [[1, 2], [3, 4]]
// we will generate:
// [ 1, 2, 3, 4 ]
// Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
// Attributes are the way MLIR attaches constant to operations and functions.
void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
/// Recursive helper function to accumulate the data that compose an array
/// literal. It flattens the nested structure in the supplied vector. For
/// example with this array:
/// [[1, 2], [3, 4]]
/// we will generate:
/// [ 1, 2, 3, 4 ]
/// Individual numbers are represented as doubles.
/// Attributes are the way MLIR attaches constants to operations.
void collectData(ExprAST &expr, std::vector<double> &data) {
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
for (auto &value : lit->getValues())
collectData(*value, data);
return;
}
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(
elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
data.push_back(attr);
data.push_back(cast<NumberExprAST>(expr).getValue());
}
// Emit a call expression. It emits specific operations for the `transpose`
// builtin. Other identifiers are assumed to be user-defined functions.
/// Emit a call expression. It emits specific operations for the `transpose`
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());
std::string callee = call.getCallee();
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
mlir::Value *arg = mlirGen(*call.getArgs()[0]);
return builder->create<TransposeOp>(location, arg).getResult();
}
// Codegen the operands first
// Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
@@ -342,34 +326,41 @@ private:
return nullptr;
operands.push_back(arg);
}
// Calls to user-defined function are mapped to a custom call that takes
// the callee name as an attribute.
return builder->create<GenericCallOp>(location, call.getCallee(), operands)
.getResult();
// Builtin calls have their own custom operation, meaning this is a
// straightforward emission.
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
return builder.create<TransposeOp>(location, operands[0]);
}
// Otherwise, this is a call to a user-defined function. Calls to user-defined
// functions are mapped to a custom call that takes the callee name as an
// attribute.
return builder.create<GenericCallOp>(location, callee, operands);
}
// Emit a call expression. It emits specific operations for two builtins:
// transpose(x) and print(x). Other identifiers are assumed to be user-defined
// functions. Return false on failure.
bool mlirGen(PrintExprAST &call) {
/// Emit a print expression, lowering it to the builtin `toy.print` operation.
mlir::LogicalResult mlirGen(PrintExprAST &call) {
auto *arg = mlirGen(*call.getArg());
if (!arg)
return false;
auto location = loc(call.loc());
builder->create<PrintOp>(location, arg);
return true;
return mlir::failure();
builder.create<PrintOp>(loc(call.loc()), arg);
return mlir::success();
}
// Emit a constant for a single number (FIXME: semantic? broadcast?)
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
auto location = loc(num.loc());
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
loc(num.loc()));
return builder->create<ConstantOp>(location, attr).getResult();
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}
// Dispatch codegen for the right expression subclass using RTTI.
/// Dispatch codegen for the right expression subclass using RTTI.
mlir::Value *mlirGen(ExprAST &expr) {
switch (expr.getKind()) {
case toy::ExprAST::Expr_BinOp:
@@ -390,77 +381,75 @@ private:
}
}
// Handle a variable declaration, we'll codegen the expression that forms the
// initializer and record the value in the symbol table before returning it.
// Future expressions will be able to reference this variable through symbol
// table lookup.
/// Handle a variable declaration; we'll codegen the expression that forms the
/// initializer and record the value in the symbol table before returning it.
/// Future expressions will be able to reference this variable through symbol
/// table lookup.
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
mlir::Value *value = nullptr;
auto location = loc(vardecl.loc());
if (auto init = vardecl.getInitVal()) {
value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder
->create<ReshapeOp>(
location, value,
getType(vardecl.getType()).cast<ToyArrayType>())
.getResult();
}
} else {
auto init = vardecl.getInitVal();
if (!init) {
emitError(loc(vardecl.loc()),
"missing initializer in variable declaration");
"Missing initializer in variable declaration");
return nullptr;
}
// Register the value in the symbol table
declare(vardecl.getName(), value);
mlir::Value *value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder.create<ReshapeOp>(loc(vardecl.loc()),
getType(vardecl.getType()), value);
}
// Register the value in the symbol table.
if (failed(declare(vardecl.getName(), value)))
return nullptr;
return value;
}
/// Codegen a list of expression, return false if one of them hit an error.
bool mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
/// Codegen a list of expressions; return failure if one of them hits an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statements, and
// print. These can only appear in a block list and not in nested
// expressions.
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
if (!mlirGen(*vardecl))
return false;
return mlir::failure();
continue;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
if (!mlirGen(*ret))
return false;
return true;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
return mlirGen(*ret);
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
if (!mlirGen(*print))
return false;
if (mlir::failed(mlirGen(*print)))
return mlir::success();
continue;
}
// Generic expression dispatch codegen.
if (!mlirGen(*expr))
return false;
return mlir::failure();
}
return true;
return mlir::success();
}
/// Build a type from a list of shape dimensions. Types are `array` followed
/// by an optional dimension list, example: array<2, 2>
/// They are wrapped in a `toy` dialect (see next chapter) and get printed:
/// !toy.array<2, 2>
template <typename T> mlir::Type getType(T shape) {
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
return ToyArrayType::get(&context, shape64);
/// Build a tensor type from a list of shape dimensions.
mlir::Type getType(ArrayRef<int64_t> shape) {
// If the shape is empty, then this type is unranked.
if (shape.empty())
return builder.getTensorType(builder.getF64Type());
// Otherwise, we use the given shape.
return builder.getTensorType(shape, builder.getF64Type());
}
/// Build an MLIR type from a Toy AST variable type
/// (forward to the generic getType(T) above).
/// Build an MLIR type from a Toy AST variable type (forward to the generic
/// getType above).
mlir::Type getType(const VarType &type) { return getType(type.shape); }
};

View File

@ -1,4 +1,4 @@
//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,22 +15,14 @@
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing interprocedural
// This file implements a Function level pass performing intra-procedural
// propagation of array shapes.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
@ -39,48 +31,26 @@
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#define DEBUG_TYPE "toy-shape-inference"
#define DEBUG_TYPE "shape-inference"
using namespace toy;
using llvm::MutableArrayRef;
using llvm::raw_ostream;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;
/// Create a mangled name for function specialization. We will simply append the
/// shape of the arguments to the function name. For example, calling
///
/// "toy.generic_call"(%1, %3) {callee: "foo"}
/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
///
/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
/// have provided a function with a similar name, but we will claim this as a
/// feature: this allows the user to provide custom specializations!
static std::string mangle(StringRef funcName,
MutableArrayRef<mlir::OpOperand> operands) {
std::string mangledName;
mangledName.reserve(funcName.size() + operands.size() * 6);
mangledName = funcName;
for (auto &operand : operands) {
auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
mangledName += "_";
mlir::interleave(
arrayTy.getShape(),
[&](int64_t dim) { mangledName += Twine(dim).str(); },
[&]() { mangledName += "x"; });
}
return mangledName;
}
using namespace mlir;
namespace {
/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
/// whole. MLIR also supports FunctionPasses, which are restricted to modifying a
/// single function at a time. This pass couldn't be a function pass due to the
/// nature of its interprocedural transformations.
// clang-format off
#include "toy/ShapeInferenceOpInterfaces.h.inc"
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference.
///
/// The algorithm has two levels, first intra-procedurally:
/// Algorithm:
///
/// 1) Build a worklist containing all the operations that are returning
/// a generic Toy array: these are the operations that need shape
@ -94,132 +64,25 @@ namespace {
/// 3) If the worklist is empty, the algorithm succeeded and we infer the
/// return type for the function from the return operation.
///
/// There is a twist though: when a call to a generic function is encountered,
/// shape inference requires the return type of the callee to be inferred first.
/// At this point we need to specialize the callee by cloning it. Here is
/// the inter-procedural flow:
///
/// 1) Keep a worklist of functions to process. Start with function "main".
/// 2) While the worklist isn't empty:
/// a) Take the last inserted function in the worklist.
/// b) Run the intra-procedural shape inference on this function.
/// c) If the intra-procedural shape inference can't complete, it returns
/// a Function that needs to be inferred first. In this case, queue this
/// new function and continue. Otherwise the inference succeeded and we
/// can pop from the queue.
///
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
// One entry in the inter-procedural worklist. It keeps track of the
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
mlir::FuncOp function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
void runOnModule() override {
auto module = getModule();
auto main = module.lookupSymbol<mlir::FuncOp>("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"shape inference failed: can't find a main function\n");
signalPassFailure();
return;
}
/// Inter-procedural loop, initialize with `main` and iterate until we
/// successfully infer the full reachable call-graph from main.
SmallVector<FunctionToSpecialize, 8> worklist;
worklist.push_back({main, "", {}});
while (!worklist.empty()) {
if (failed(specialize(worklist)))
return;
}
// Delete any generic function left
// FIXME: we may want this as a separate pass.
for (mlir::FuncOp function :
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
function.erase();
}
bool returnsGenericArray(Operation *op) {
if (op->getNumResults() == 1) {
if (!op->getResult(0)->getType().isa<ShapedType>())
return true;
}
return false;
}
/// Run inference on a function. If a mangledName is provided, we need to
/// specialize the function: to this end clone it first.
mlir::LogicalResult
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
mlir::FuncOp f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
// and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) {
if (getModule().lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back();
// Function already specialized, move on.
return mlir::success();
}
// Create a new function with a generic array return type, it will be
// updated when the inference for the function body completes.
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
auto newFunction =
mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
type, f.getDialectAttrs());
getModule().push_back(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
f.dump();
llvm::dbgs() << "====== Into : \n";
newFunction.dump();
});
f = newFunction;
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
assert(succeeded(mlir::verify(f)));
}
LLVM_DEBUG(llvm::dbgs()
<< "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
emitError(mlir::UnknownLoc::get(&getContext()),
"Toy dialect is not registered");
signalPassFailure();
return mlir::failure();
}
void runOnFunction() override {
auto f = getFunction();
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
// these are operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
opWorklist.insert(op);
if (returnsGenericArray(op)) {
opWorklist.insert(op);
}
});
@ -228,154 +91,31 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
return !ty.cast<ToyArrayType>().isGeneric();
});
auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
return this->returnsGenericArray(op);
});
if (nextop == opWorklist.end())
break; // failure: no operations can be inferred.
mlir::Operation *op = *nextop;
Operation *op = *nextop;
opWorklist.erase(op);
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
// The add operation is trivial: propagate the input type as is.
if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
op->getResult(0)->setType(op->getOperand(0)->getType());
continue;
}
// Transpose is easy: just invert the dimensions.
if (auto transpose = llvm::dyn_cast<TransposeOp>(op)) {
SmallVector<int64_t, 2> dims;
auto arrayTy = transpose.getOperand()->getType().cast<ToyArrayType>();
dims.insert(dims.end(), arrayTy.getShape().begin(),
arrayTy.getShape().end());
transpose.getResult()->setType(ToyArrayType::get(&getContext(), dims));
continue;
}
// Multiplication is a bit trickier, handle rank 1 as dot product and rank
// 2 as matrix multiplications.
// We need to be careful about rank mismatch here: the verifier could
// catch it but shape inference earlier in the pass could generate an
// invalid IR (from an invalid Toy input of course) and we wouldn't want
// to crash here.
if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank) {
return mulOp.emitOpError(
"shape mismatch: LHS and RHS must have the same "
"rank for multiplication, got " +
Twine(lhsRank) + " vs " + Twine(rhsRank));
}
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else if (lhsRank != 2) {
return op->emitOpError(
"shape mismatch: expect rank 1 or 2 for mul operands, got " +
Twine(lhsRank));
} else {
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
}
op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
continue;
}
// Process calls: lookup the callee after mangling the name with the
// argument shapes. If the callee does not exist, we stop the inference
// for this function, queue the callee in the inter-procedural work list,
// and return. The current function stays in the work list and will
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) {
f.emitError("shape inference failed, call to unknown '")
<< calleeName << "'";
signalPassFailure();
return mlir::failure();
}
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto mangledCallee =
getModule().lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
funcWorklist.push_back({callee, std::move(mangledName),
llvm::to_vector<4>(op->getOperandTypes())});
return mlir::success();
}
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands(op->getOperands());
mlir::OpBuilder builder(op);
auto newCall =
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
if (newCall.getNumResults()) {
op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
op->erase();
continue;
}
}
auto shapeOp = dyn_cast<ShapeInference>(op);
shapeOp.inferShapes();
}
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
std::string str;
llvm::raw_string_ostream errorMsg(str);
errorMsg << "shape inference failed, " << opWorklist.size()
<< " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
errorMsg << " - " << *ope << "\n";
f.emitError(errorMsg.str());
signalPassFailure();
return mlir::failure();
auto diag = f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
}
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
f.setType(newType);
assert(succeeded(mlir::verify(f)));
break;
}
return mlir::success();
}
};
} // end anonymous namespace
namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass() {
/// Create a Shape Inference pass.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>();
}
} // namespace toy

View File

@ -15,24 +15,25 @@
// limitations under the License.
// =============================================================================
//
// This file implements a simple combiner for optimizing pattern in the Toy
// dialect.
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "toy/Dialect.h"
#include <numeric>
namespace toy {
using namespace mlir;
using namespace toy;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold transpose(transpose(x)) -> x
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
@ -40,9 +41,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
@ -50,106 +51,28 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
if (!transposeInputOp)
return matchFailure();
// Use the rewriter to perform the replacement
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
}
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
// If the input is defined by another constant, bingo!
if (!constantOp)
return matchFailure();
auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
// Broadcast
auto dataSize = std::accumulate(reshapeType.getShape().begin(),
reshapeType.getShape().end(), 1,
std::multiplies<int>());
std::vector<mlir::Attribute> data(dataSize, valueAttr);
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
}
return matchSuccess();
}
};
/// Fold reshape(reshape(x)) -> reshape(x)
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current reshape.
mlir::Value *reshapeInput = op.getOperand();
// If the input is defined by another reshape, bingo!
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
// Use the rewriter to perform the replacement
rewriter.replaceOp(op, {reshapeInput});
return matchSuccess();
}
};
/// Fold reshape(x) -> x, when input type matches output type
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
if (op.getOperand()->getType() != op.getType())
return matchFailure();
rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};
} // end anonymous namespace.
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context);
}
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
SimplifyNullReshape>(context);
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
FoldConstantReshapeOptPattern>(context);
}
} // namespace toy

View File

@ -0,0 +1,73 @@
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines language-specific pattern match optimizations for Toy using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//
#ifndef TOY_COMBINE
#define TOY_COMBINE
#ifndef OP_BASE
include "toy/Ops.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// Reshape(Reshape(x)) = x
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
(ReshapeOp $arg)>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite using Native Code Call
//===----------------------------------------------------------------------===//
// Native Code Calls may be used for more complex transformations using inline
// C++ and C++ helper functions.
// Reshape(Constant(x)) = x'
def ReshapeConstant :
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
(ReshapeOp:$res (ConstantOp $arg)),
(ConstantOp (ReshapeConstant $arg, $res))>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite with Constraints
//===----------------------------------------------------------------------===//
// DRR allows for constraint checking when the transformation is conditional
// on operand properties.
// Reshape(x) = x, where input and output shapes are identical
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
(ReshapeOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>;
#endif // TOY_COMBINE

View File

@ -1,387 +0,0 @@
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace toy {
namespace detail {
/// This class holds the implementation of the ToyArrayType.
/// It is intended to be uniqued based on its content and owned by the context.
struct ToyArrayTypeStorage : public mlir::TypeStorage {
/// This defines how we unique this type in the context: our key contains
/// only the shape, a more complex type would have multiple entries in the
/// tuple here.
/// The element of the tuples usually matches 1-1 the arguments from the
/// public `get()` method arguments from the facade.
using KeyTy = std::tuple<ArrayRef<int64_t>>;
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key));
}
/// When the key hash hits an existing type, we compare the shapes themselves
/// to confirm we have the right type.
bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
/// This is a factory method to create our type storage. It is only
/// invoked after looking up the type in the context using the key and not
/// finding it.
static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
const KeyTy &key) {
// Copy the shape array into the bumpptr allocator owned by the context.
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
// Allocate the instance for the ToyArrayTypeStorage itself
auto *storage = allocator.allocate<ToyArrayTypeStorage>();
// Initialize the instance using placement new.
return new (storage) ToyArrayTypeStorage(shape);
}
ArrayRef<int64_t> getShape() const { return shape; }
private:
ArrayRef<int64_t> shape;
/// Constructor is only invoked from the `construct()` method above.
ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
};
} // namespace detail
mlir::Type ToyArrayType::getElementType() {
return mlir::FloatType::getF64(getContext());
}
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
ArrayRef<int64_t> shape) {
return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
}
ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
MulOp, AddOp, ReturnOp>();
addTypes<ToyArrayType>();
}
/// Parse a type registered to this dialect, we expect only Toy arrays.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
// Sanity check: we only support array or array<...>
if (!tyData.startswith("array")) {
emitError(loc, "invalid Toy type '" + tyData + "', array expected");
return nullptr;
}
// Drop the "array" prefix from the type name, we expect either an empty
// string or just the shape.
tyData = tyData.drop_front(StringRef("array").size());
// This is the generic array case without shape, early return it.
if (tyData.empty())
return ToyArrayType::get(getContext());
// Use a regex to parse the shape (for efficiency we should store this regex in
// the dialect itself).
SmallVector<StringRef, 4> matches;
auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
if (!shapeRegex.match(tyData, &matches)) {
emitError(loc, "invalid toy array shape '" + tyData + "'");
return nullptr;
}
SmallVector<int64_t, 4> shape;
// Iterate through the captures, skip the first one which is the full string.
for (auto dimStr :
llvm::make_range(std::next(matches.begin()), matches.end())) {
if (dimStr.startswith(","))
continue; // POSIX misses non-capturing groups.
if (dimStr.empty())
continue; // '*' makes it an optional group capture
// Convert the capture to an integer
unsigned long long dim;
if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
return mlir::Type();
}
shape.push_back(dim);
}
// Finally we collected all the dimensions in the shape,
// create the array type.
return ToyArrayType::get(getContext(), shape);
}
/// Print a Toy array type, for example `array<2, 3, 4>`
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
auto arrayTy = type.dyn_cast<ToyArrayType>();
if (!arrayTy) {
os << "unknown toy type";
return;
}
os << "array";
if (!arrayTy.getShape().empty()) {
os << "<";
mlir::interleaveComma(arrayTy.getShape(), os);
os << ">";
}
}
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Helper to verify that the result of an operation is a Toy array type.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
if (!op->getResult()->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its argument, got "
<< op->getResult()->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
/// Helper to verify that the two operands of a binary operation are Toy
/// arrays.
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its LHS, got "
<< op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its LHS, got "
<< op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
auto dataAttribute = builder->getNamedAttr("value", value);
state.attributes.push_back(dataAttribute);
}
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::FloatAttr value) {
// Broadcast and forward to the other build factory
mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
auto dataType = builder->getTensorType({1}, elementType);
auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
.cast<mlir::DenseElementsAttr>();
ConstantOp::build(builder, state, {1}, dataAttribute);
}
/// Verifier for constant operation.
mlir::LogicalResult ConstantOp::verify() {
// Ensure that the return type is a Toy array
if (failed(verifyToyReturnArray(this)))
return mlir::failure();
// We expect the constant itself to be stored as an attribute.
auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
if (!dataAttr) {
return emitOpError(
"missing valid `value` DenseElementsAttribute on toy.constant()");
}
auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
if (!attrType) {
return emitOpError(
"missing valid `value` DenseElementsAttribute on toy.constant()");
}
// If the return type of the constant is not a generic array, the shape must
// match the shape of the attribute holding the data.
auto resultType = getResult()->getType().cast<ToyArrayType>();
if (!resultType.isGeneric()) {
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("The rank of the toy.constant return type must match "
"the one of the attached value attribute: " +
Twine(attrType.getRank()) +
" != " + Twine(resultType.getRank()));
}
for (int dim = 0; dim < attrType.getRank(); ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
std::string msg;
raw_string_ostream os(msg);
return emitOpError(
"Shape mismatch between toy.constant return type and its "
"attribute at dimension " +
Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
" != " + Twine(resultType.getShape()[dim]));
}
}
}
return mlir::success();
}
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns a generic ToyArray initially
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.assign(arguments.begin(), arguments.end());
auto calleeAttr = builder->getStringAttr(callee);
state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
}
mlir::LogicalResult GenericCallOp::verify() {
// Verify that every operand is a Toy Array
for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its " << opId << " operand, got "
<< getOperand(opId)->getType();
return emitOpError(os.str());
}
}
return mlir::success();
}
/// Return the name of the callee.
StringRef GenericCallOp::getCalleeName() {
return getAttr("callee").cast<mlir::StringAttr>().getValue();
}
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its argument, got "
<< op->getOperand()->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
// Return does not return any value and has an optional single argument
if (value)
state.operands.push_back(value);
}
mlir::LogicalResult ReturnOp::verify() {
if (getNumOperands() > 1)
return emitOpError("expects zero or one operand, got " +
Twine(getNumOperands()));
if (hasOperand() && failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
// Print does not return any value and has a single argument
state.operands.push_back(value);
}
mlir::LogicalResult PrintOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(value);
}
mlir::LogicalResult TransposeOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, ToyArrayType reshapedType) {
state.types.push_back(reshapedType);
state.operands.push_back(value);
}
mlir::LogicalResult ReshapeOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
if (!retTy)
return emitOpError("toy.reshape is expected to produce a Toy array");
if (retTy.isGeneric())
return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
"got a generic one.");
return mlir::success();
}
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(lhs);
state.operands.push_back(rhs);
}
mlir::LogicalResult AddOp::verify() {
if (failed(verifyToyBinOperands(this)))
return mlir::failure();
return mlir::success();
}
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(lhs);
state.operands.push_back(rhs);
}
mlir::LogicalResult MulOp::verify() {
if (failed(verifyToyBinOperands(this)))
return mlir::failure();
return mlir::success();
}
} // namespace toy

View File

@ -80,54 +80,63 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule();
}
mlir::LogicalResult optimize(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).endswith(".mlir")) {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
return !module ? 1 : 0;
}
return pm.run(module);
// Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return -1;
}
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
}
return 0;
}
int dumpMLIR() {
// Register our Dialect with MLIR
mlir::registerDialect<ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return -1;
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
}
if (failed(mlir::verify(*module))) {
llvm::errs() << "Error verifying MLIR module\n";
return 4;
}
} else {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
}
if (!module)
return 1;
if (int error = loadMLIR(context, module))
return error;
if (EnableOpt) {
if (failed(optimize(*module))) {
llvm::errs() << "Module optimization failed\n";
return 7;
}
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
// Add a run of the canonicalizer to optimize the mlir module.
pm.addPass(mlir::createCanonicalizerPass());
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
// Now that there is only one function, we can infer the shapes of each of
// the operations.
pm.addPass(mlir::toy::createShapeInferencePass());
if (mlir::failed(pm.run(*module)))
return 4;
}
module->dump();
return 0;
}

View File

@ -1,242 +1,118 @@
# Chapter 4: High-level Language-Specific Analysis and Transformation
# Chapter 4: Using Interfaces
Creating a dialect that closely represents the semantics of an input language
enables analyses and transformations in MLIR that are generally performed on the
language AST. For example, `clang` has a fairly
[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
for performing template instantiation in C++.
[Interfaces](../../Interfaces.md) provide a generic method for applying
transformations across dialects. We first describe how to leverage an existing
MLIR interface, and then walk through writing your own interface.
Another aspect is optimization. While some previous language specific
optimizations have been implemented in LLVM (like the
[ARC optimizer](http://llvm.org/doxygen/ObjCARCOpts_8cpp_source.html#l00468)),
it has been at the cost of relying either on adding enough concepts in LLVM to
be able to embed the high-level semantics of the input, or on using fragile
"best-effort" metadata to decorate the IR with the information needed for these
custom optimizations.
## Function Inlining
We show in this chapter how to leverage the Toy Dialect and its high-level
semantics to perform transformations that would be difficult in LLVM: first a
simple combine of two redundant operations, and second a full interprocedural
shape inference with function specialization.
To apply function inlining in the Toy dialect, we override the
DialectInlinerInterface, enable inlining, and add special handling for the
return operation (a sketch of how the dialect registers this interface follows
the struct below):
# Basic Optimization: Eliminate Redundant Transpose
```Toy(.cpp)
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//
Let's start with a simple pattern and try to eliminate a sequence of two
transposes that cancel out: `transpose(transpose(X)) -> X`. Here is the
corresponding Toy example:
/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
```Toy(.toy)
def transpose_transpose(x) {
return transpose(transpose(x));
}
```
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
Which corresponds to the following IR:
```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
attributes {toy.generic: true} {
%0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
%1 = "toy.transpose"(%0) : (!toy<"array">) -> !toy<"array">
"toy.return"(%1) : (!toy<"array">) -> ()
}
```
This is a good example of a transformation that is trivial to match on the Toy
IR but that would be quite hard for LLVM to figure out. For example, today clang
can't optimize away the temporary array or the computation for the naive
transpose expressed with these loops:
```c++
#define N 100
#define M 100
void sink(void *);
void double_transpose(int A[N][M]) {
int B[M][N];
for(int i = 0; i < N; ++i) {
for(int j = 0; j < M; ++j) {
B[j][i] = A[i][j];
}
/// All operations within toy can be inlined.
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
for(int i = 0; i < N; ++i) {
for(int j = 0; j < M; ++j) {
A[i][j] = B[j][i];
}
}
sink(A);
}
```
For a simple rewrite involving matching a tree-like pattern in the IR and
replacing it with a different set of operations, we can plug into the MLIR
`Canonicalizer` pass by implementing a `RewritePattern`:
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
```c++
/// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, context) {}
/// Handle the given inlined terminator (toy.return) by replacing it with a new
/// operation as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value *> valuesToRepl) const final {
// Only "toy.return" needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult matchAndRewrite(
mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
// We can directly cast the current operation as this will only get invoked
// on TransposeOp.
TransposeOp transpose = op->cast<TransposeOp>();
// look through the input to the current transpose
mlir::Value *transposeInput = transpose.getOperand();
// If the input is defined by another Transpose, bingo!
if (!matchPattern(transposeInput, mlir::m_Op<TransposeOp>()))
return matchFailure();
auto transposeInputOp =
transposeInput->getDefiningOp()->cast<TransposeOp>();
// Use the rewriter to perform the replacement
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
};
```
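The interface only takes effect once the dialect registers it. Here is a
minimal, hedged sketch of what this registration could look like in the
`ToyDialect` constructor; the surrounding constructor body is illustrative and
not the verbatim `mlir/Dialect.cpp`:

```Toy(.cpp)
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  // Attach the inliner interface so that the generic inliner pass knows how to
  // handle Toy operations and their terminator.
  addInterfaces<ToyInlinerInterface>();
}
```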
Let's see how to improve our `TransposeOp` by extending it with a new static
method:
Next, we call into the interface by adding an inliner pass to the pass manager
for toy:
```c++
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
static void getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context) {
results.push_back(std::make_unique<SimplifyRedundantTranspose>(context));
```Toy(.cpp)
pm.addPass(mlir::createInlinerPass());
```
** Insert example here **
## Shape Inference
The Toy language allows for implicit shapes and hence requires shape inference.
We implement shape inference as a generic
[Operation Interface](../../Interfaces.md#operation-interfaces).
1. We first create the ShapeInferenceOpInterface by specializing the
OpInterface class using [ODS](../../OpDefinitions.md#operation-interfaces).
This class defines interface methods that Toy operations must override for
shape inference.
```Toy(.cpp)
def ShapeInferenceOpInterface : OpInterface<"ShapeInferenceOpInterface"> {
let methods = [
InterfaceMethod<
"bool", "returnsGenericArray", (ins), [{
if (op.getNumResults() == 1) {
auto arrayTy = op.getResult()->getType().cast<RankedTensorType>();
return arrayTy.getShape().empty();
}
return false;
}]>,
InterfaceMethod<"void", "inferShapes", (ins), [{}]>
];
}
```
2. Next, we override the inferShapes() method within Toy operations. As an
example, for the transpose op, the result shape is inferred by swapping the
dimensions of the input tensor; a sketch for an element-wise op follows the
transpose example below.
```Toy(.cpp)
void inferShapes() {
SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
dims.insert(dims.end(), arrayTy.getShape().begin(),
arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return;
}
```
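For element-wise operations such as addition, inference is even simpler: the
result has the same shape as the operands. A minimal sketch, assuming the
pointer-based `mlir::Value` API used throughout this chapter (the accessor
spelling below goes through the underlying `Operation` to stay generic):

```Toy(.cpp)
void inferShapes() {
  // Element-wise ops produce a result shaped like their (matching) operands,
  // so simply forward the type of the first operand to the result.
  auto *operand = getOperation()->getOperand(0);
  getOperation()->getResult(0)->setType(operand->getType());
}
```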
The implementation of this rewriter is in `ToyCombine.cpp`. We also need to
update our main file, `toyc.cpp`, to add an optimization pipeline. In MLIR, the
optimizations are run through a `PassManager` in a similar way to LLVM:
3. We then create a generic ShapeInference FunctionPass that uses operation
casting to access the inferShapes() method. This is an intra-procedural shape
inference pass that executes after function inlining and iterates over
operations in a worklist, calling inferShapes() for each operation with unknown
result shapes. A condensed sketch of the pass body follows.
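Here is a condensed, hedged sketch of what the body of such a pass looks like;
the full implementation lives in `mlir/ShapeInferencePass.cpp`, and
`allOperandsInferred` is a hypothetical helper standing in for the
operand-readiness check:

```Toy(.cpp)
void runOnFunction() override {
  auto f = getFunction();

  // Populate the worklist with the operations that still return a generic
  // (unranked) array and hence need shape inference.
  llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
  f.walk([&](mlir::Operation *op) {
    if (returnsGenericArray(op))
      opWorklist.insert(op);
  });

  // Iterate until every operation in the worklist has been processed.
  while (!opWorklist.empty()) {
    // Pick an operation whose operands are all resolved already.
    auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
      return allOperandsInferred(op); // hypothetical readiness check
    });
    if (nextop == opWorklist.end())
      break;

    mlir::Operation *op = *nextop;
    opWorklist.erase(op);

    // Delegate the actual inference to the operation through the interface
    // class generated from the ODS definition.
    if (auto shapeOp = dyn_cast<ShapeInference>(op))
      shapeOp.inferShapes();
  }

  // Any leftover operation means inference could not complete.
  if (!opWorklist.empty())
    signalPassFailure();
}
```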
```c++
mlir::PassManager pm(ctx);
pm.addPass(mlir::createCanonicalizerPass());
pm.run(&module);
4. Finally, we call into the shape inference pass by adding it to the pass
manager for toy:
```Toy(.cpp)
pm.addPass(mlir::createShapeInferencePass());
```
Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
and observe our pattern in action:
```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
attributes {toy.generic: true} {
%0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
"toy.return"(%arg0) : (!toy<"array">) -> ()
}
```
As expected we now directly return the function argument, bypassing any
transpose operation. However, one of the transposes hasn't been eliminated. That
is not ideal! What happened is that our pattern replaced the last transpose with
the function input and left behind the now-dead transpose input. The
Canonicalizer knows to clean up dead operations; however, MLIR conservatively
assumes that operations may have side effects. We can fix this by adding a new
trait, `HasNoSideEffect`, to our `TransposeOp`:
```c++
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
```
Let's retry now `toyc test/transpose_transpose.toy -emit=mlir -opt`:
```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
attributes {toy.generic: true} {
"toy.return"(%arg0) : (!toy<"array">) -> ()
}
```
Perfect! No `transpose` operation is left; the code is optimal.
The code in `mlir/ToyCombine.cpp` implements a few more patterns that eliminate
trivial reshapes, or fold them into constants.
# Shape Inference and Generic Function Specialization
Our IR operates on generic arrays: we don't know the shape of the arrays other
than during initialization of constants. However, we can propagate the shapes
through the computation until they are all known. The issue is how to handle
calls to user-defined generic functions: every call site could deduce different
shapes. One possibility would be to perform symbolic inference based on the
argument types, but this would be hard to generalize if we were to introduce
more control flow in the language. Instead we will proceed by function
specialization: for every call site with new argument shapes we duplicate the
function and specialize it. This is akin to C++ template instantiation:
```
template<int M1, int N1, int M2, int N2>
auto multiply_add(array<M1, N1> a, array<M1, N1> b) {
auto prod = mul(a, b);
auto sum = add(prod, a);
return sum;
}
```
Every new call to `multiply_add` would instantiate the template and emit code
for the specific shape and deduce the return type. Clang implements this
transformation on its AST, but we will implement it in an MLIR pass here.
The ShapeInferencePass is a `ModulePass`: it will run on the Module as a whole.
MLIR also supports `FunctionPass`es, which are restricted to modifying a single
function at a time. This pass couldn't be a function pass due to the nature of
its interprocedural transformations.
Implementing such a pass is done by creating a class inheriting from
`mlir::ModulePass` and overriding the `runOnModule()` method:
```
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
void runOnModule() override {
auto &module = getModule();
...
```
The algorithm has two levels, first intra-procedurally:
1. Build a worklist containing all the operations that are returning a generic
Toy array: these are the operations that need shape inference.
2. Iterate on the worklist:
- find an operation to process: the next ready operation in the worklist
has all of its arguments non-generic,
- if no operation is found, break out of the loop,
- remove the operation from the worklist,
- infer the shape of its output from the arguments type.
3. If the worklist is empty, the algorithm succeeded and we infer the return
type for the function from the return operation.
There is a twist though: when a call to a generic function is encountered, shape
inference requires the return type of the callee to be inferred first. At this
point we need to specialize the callee by cloning it. Here is the
inter-procedural flow that wraps the intra-procedural inference:
1. Keep a worklist of functions to process. Start with function "main".
2. While the worklist isn't empty:
- Take the last inserted function in the worklist.
- Run the intra-procedural shape inference on this function.
- If the intra-procedural shape inference can't complete, it returns a
FuncOp that needs to be inferred first. In this case, queue this new
function and continue. Otherwise the inference succeeded and we can pop
from the queue.
The full code is in `mlir/ShapeInferencePass.cpp`.
# Future Work: Optimizing Buffer Allocation?
Toy is value-based. Naively this means a lot of allocation; what if we want to
statically optimize placement? What is the right abstraction level to perform
buffer assignment?
** Insert example here **

View File

@ -141,6 +141,9 @@ std::unique_ptr<OpPassBase<FuncOp>> createStripDebugInfoPass();
/// Creates a pass which tests loop fusion utilities.
std::unique_ptr<OpPassBase<FuncOp>> createTestLoopFusionPass();
/// Creates a pass which inlines calls and callable operations as defined by the
/// CallGraph.
std::unique_ptr<Pass> createInlinerPass();
} // end namespace mlir
#endif // MLIR_TRANSFORMS_PASSES_H

View File

@ -291,4 +291,8 @@ struct InlinerPass : public OperationPass<InlinerPass> {
};
} // end anonymous namespace
std::unique_ptr<Pass> mlir::createInlinerPass() {
return std::make_unique<InlinerPass>();
}
static PassRegistration<InlinerPass> pass("inline", "Inline function calls");

View File

@ -10,7 +10,7 @@ def main() {
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
# The shape is inferred from the supplied literal.
var a = [[1, 2, 3], [4, 5, 6]];
# b is identical to a, the literal array is implicitely reshaped: defining new
# b is identical to a, the literal array is implicitly reshaped: defining new
# variables is the way to reshape arrays (element count must match).
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# This call will specialize `multiply_transpose` with <2, 3> for both

View File

@ -13,20 +13,19 @@ def main() {
print(d);
}
# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
# CHECK-NEXT: attributes {toy.generic = true} {
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
# CHECK-NEXT: }
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
# CHECK-NEXT: attributes {toy.generic} {
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()
# CHECK-LABEL: func @main() {
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()

View File

@ -1,11 +1,9 @@
// RUN: not toyc-ch4 %s -emit=mlir 2>&1
// This IR is not "valid":
// The following IR is not "valid":
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
// This all round-trips since it is opaque to MLIR.
func @main() {
%0 = "toy.print"() : () -> !toy.array<2, 3>
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -6,9 +6,9 @@ def main() {
}
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-NEXT: }

View File

@ -1,19 +0,0 @@
# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
def transpose_transpose(x) {
return transpose(transpose(x));
}
def main() {
print(transpose_transpose([[1, 2], [3, 4]]));
}
#CHECK-LABEL: func @transpose_transpose
#CHECK: transpose
#CHECK-LABEL: main
#OPT-LABEL: func @transpose_transpose
#OPT-NOT: transpose

View File

@ -1,24 +0,0 @@
# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
# We expect no reshape in this function with optimizations enabled
def foo(a) {
var b<2,1> = a;
var c<2,1> = b;
print(c);
}
def main() {
var a<2, 1> = [1, 2];
foo(a);
}
# without optimizations, match the reshape
#CHECK-LABEL: func @foo
#CHECK: reshape
#CHECK-LABEL: main
# with optimizations, ensure no reshape
#OPT-LABEL: main
#OPT-LABEL: func @foo_2x1
#OPT-NOT: reshape