Add Ch.5 of the toy tutorial.

This chapter adds a partial lowering of toy operations, all but PrintOp, to a combination of the Affine and Std dialects. This chapter focuses on introducing the conversion framework, the benefits of partial lowering, and how easily dialects may co-exist in the IR.

PiperOrigin-RevId: 275150649
This commit is contained in:
River Riddle 2019-10-16 17:33:34 -07:00 committed by A. Unique TensorFlower
parent 7045471913
commit 1ba9bb0507
36 changed files with 1961 additions and 2686 deletions

View File

@ -55,7 +55,8 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Verifier for constant operation.
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
@ -63,6 +64,8 @@ static mlir::LogicalResult verify(ConstantOp op) {
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = op.value().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op.emitOpError(
@ -70,7 +73,9 @@ static mlir::LogicalResult verify(ConstantOp op) {
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
for (int dim = 0; dim < attrType.getRank(); ++dim) {
// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return op.emitOpError(
"return type shape mismatches its attribute at dimension ")

View File

@ -118,6 +118,8 @@ static mlir::LogicalResult verify(ConstantOp op) {
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = op.value().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op.emitOpError(
@ -125,7 +127,9 @@ static mlir::LogicalResult verify(ConstantOp op) {
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
for (int dim = 0; dim < attrType.getRank(); ++dim) {
// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return op.emitOpError(
"return type shape mismatches its attribute at dimension ")

View File

@ -79,7 +79,7 @@ public:
// the structural properties of the IR and invoke any specific verifiers we
// have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("Module verification error");
theModule.emitError("module verification error");
return nullptr;
}
@ -229,7 +229,7 @@ private:
if (auto *variable = symbolTable.lookup(expr.getName()))
return variable;
emitError(loc(expr.loc()), "Error: unknown variable '")
emitError(loc(expr.loc()), "error: unknown variable '")
<< expr.getName() << "'";
return nullptr;
}
@ -289,7 +289,8 @@ private:
auto dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
// Build the MLIR op `toy.constant`.
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
// method.
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
@ -389,7 +390,7 @@ private:
auto init = vardecl.getInitVal();
if (!init) {
emitError(loc(vardecl.loc()),
"Missing initializer in variable declaration");
"missing initializer in variable declaration");
return nullptr;
}

View File

@ -1,40 +1,42 @@
add_subdirectory(include)
set(LLVM_LINK_COMPONENTS
Support
)
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
add_public_tablegen_target(ToyCh5CombineIncGen)
add_toy_chapter(toyc-ch5
toyc.cpp
parser/AST.cpp
mlir/EarlyLowering.cpp
mlir/LateLowering.cpp
mlir/MLIRGen.cpp
mlir/Dialect.cpp
mlir/DeadFunctionEliminationPass.cpp
mlir/LowerToAffineLoops.cpp
mlir/ShapeInferencePass.cpp
mlir/ToyDialect.cpp
mlir/ToyCombine.cpp
)
add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen)
add_dependencies(toyc-ch5 ToyCh5OpsIncGen)
add_dependencies(toyc-ch5 ToyCh5CombineIncGen)
add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen)
include_directories(include/)
include_directories(../../Linalg/Linalg1/include/)
include_directories(../../Linalg/Linalg2/include/)
include_directories(../../Linalg/Linalg3/include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch5
PRIVATE
Linalg3DialectConstruction
Linalg3
Linalg2
Linalg1
MLIRAffineOps
MLIRAnalysis
MLIREDSC
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRParser
MLIRPass
MLIRTargetLLVMIR
MLIRTransforms
MLIRSupport
)
MLIRStandardOps
MLIRTransforms)
whole_archive_link(toyc-ch5
MLIRAffineOps
MLIRStandardOps
)
)

View File

@ -0,0 +1 @@
add_subdirectory(toy)

View File

@ -33,10 +33,9 @@
namespace toy {
/// A variable
/// A variable type with shape information.
struct VarType {
enum { TY_FLOAT, TY_INT } elt_ty;
std::vector<int> shape;
std::vector<int64_t> shape;
};
/// Base class for all expression nodes.
@ -50,9 +49,7 @@ public:
Expr_Var,
Expr_BinOp,
Expr_Call,
Expr_Print, // builtin
Expr_If,
Expr_For,
Expr_Print,
};
ExprAST(ExprASTKind kind, Location location)
@ -85,7 +82,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};
///
/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;
@ -116,7 +113,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};
///
/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
@ -136,7 +133,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};
///
/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;

View File

@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
add_public_tablegen_target(ToyCh5OpsIncGen)
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen)

View File

@ -16,7 +16,7 @@
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
// See g3doc/Tutorials/Toy/Ch-3.md for more information.
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//
@ -25,369 +25,31 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "toy/ShapeInferenceInterface.h"
namespace mlir {
class Builder;
}
namespace toy {
/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and register custom operations and types (in its constructor).
/// It can also overriding general behavior of dialects exposed as virtual
/// method, for example regarding verification and parsing/printing.
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override some general behavior exposed via virtual
/// methods.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);
/// Parse a type registered to this dialect. Overriding this method is
/// required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
mlir::Type parseType(llvm::StringRef tyData,
mlir::Location loc) const override;
/// Print a type registered to this dialect. Overriding this method is
/// only required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
static llvm::StringRef getDialectNamespace() { return "toy"; }
};
////////////////////////////////////////////////////////////////////////////////
/////////////////////// Custom Types for the Dialect ///////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace detail {
struct ToyArrayTypeStorage;
}
/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
enum ToyTypeKind {
// The enum starts at the range reserved for this dialect.
TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
TOY_ARRAY,
};
/// Type for Toy arrays.
/// In MLIR Types are reference to immutable and uniqued objects owned by the
/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
/// provides the public facade API to interact with the type.
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
detail::ToyArrayTypeStorage> {
public:
using Base::Base;
/// Returns the dimensions for this array, or and empty range for a generic
/// array.
llvm::ArrayRef<int64_t> getShape();
/// Predicate to test if this array is generic (shape haven't been inferred
/// yet).
bool isGeneric() { return getShape().empty(); }
/// Return the rank of this array (0 if it is generic).
int getRank() { return getShape().size(); }
/// Return the type of individual elements in the array.
mlir::Type getElementType();
/// Get a MemRef equivalent to this array type.
mlir::MemRefType toMemref();
/// Get the unique instance of this Type from the context.
/// A ToyArrayType is only defined by the shape of the array.
static ToyArrayType get(mlir::MLIRContext *context,
llvm::ArrayRef<int64_t> shape = {});
/// Support method to enable LLVM-style RTTI type casting.
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
};
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Constant operation turns a literal into an SSA value. The data is attached
/// to the operation as an attribute. For example:
///
/// %0 = "toy.constant"()
/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
/// : () -> !toy<"array<2, 3>">
///
/// An operation inherits from `class Op` and specifies optional traits. Here we
/// indicate that `toy.constant` does not have any operands and returns a single
/// result. The traits provide some utilities methods for the operation, for
/// instance we will be able to use `getResult()`, but `getOperand()` won't be
/// available.
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
/// This is the name used by MLIR to match an operation to this class during
/// parsing.
static llvm::StringRef getOperationName() { return "toy.constant"; }
/// The operation can have extra verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.constant` operation does not have arguments but attaches a
/// constant array as an attribute and returns it as an SSA value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::ArrayRef<int64_t> shape,
mlir::DenseElementsAttr value);
/// Similar to the one above, but takes a single float and returns a
/// !toy<"array<1>">.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::FloatAttr value);
mlir::DenseElementsAttr getValue() {
return getAttr("value").cast<mlir::DenseElementsAttr>();
}
/// Inherit constructor.
using Op::Op;
};
/// Generic calls represent calls to a user defined function that needs to
/// be specialized for the shape of its arguments. The callee name is attached
/// as a literal string as an attribute. The arguments list must match the
/// arguments expected by the callee. For example:
///
/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
///
/// This is only valid if a function named "my_func" exists and takes two
/// arguments.
class GenericCallOp
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult> {
public:
/// MLIR will use this to register the operation with the parser/printer.
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to the builder to allow:
/// mlir::Builder::create<GenericCallOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.generic_call` operation accepts a callee name and a list of
/// arguments for the call.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::StringRef callee,
llvm::ArrayRef<mlir::Value *> arguments);
/// Return the name of the callee.
llvm::StringRef getCalleeName();
/// Inherit constructor.
using Op::Op;
};
/// Return operations terminate blocks (and functions as well). They take a
/// single argument and the type must match the function return type.
class ReturnOp
: public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
public:
static llvm::StringRef getOperationName() { return "toy.return"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.return` operation accepts an optional single array as an argument
/// and does not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value = nullptr);
/// Return true if there is a returned value.
bool hasOperand() { return 0 != getNumOperands(); }
/// Helper to return the optional operand. Caller must check if the operand
/// is present before calling this.
mlir::Value *getOperand() { return getOperation()->getOperand(0); }
/// Inherit constructor.
using Op::Op;
};
/// The print builtin takes a single array argument and does not return any.
class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::ZeroResult> {
public:
static llvm::StringRef getOperationName() { return "toy.print"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.print` operation accepts a single array as argument and does
/// not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
/// Inherit constructor.
using Op::Op;
};
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.transpose"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<TransposeOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.transpose` operation accepts a single array as argument and
/// returns the transposed array as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
// Register our patterns for rewrite by the Canonicalization framework.
static void
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context);
/// Inherit constructor.
using Op::Op;
};
/// Reshape operation is transforming its input array into a new array with the
/// same number of elements but different shapes. For example:
///
/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>">
///
class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.reshape"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<ReshapeOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.reshape` operation accepts a single array as argument and
/// returns the array with the specified reshapedType as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, ToyArrayType reshapedType);
// Register our patterns for rewrite by the Canonicalization framework.
static void
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context);
/// Inherit constructor.
using Op::Op;
};
/// Binary operation implementing a multiplication. For two-dimensional array
/// a matrix multiplication is implemented, while for one dimensional array a
/// dot product is performed.
class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.mul"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.mul` operation accepts two operands as argument and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Element wise addition of two arrays. The shape must match.
class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.add"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.mul` operation accepts two operands as argument and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// AllocOp is a temporary operation for buffer allocation, created as part of
/// partial lowering.
class AllocOp : public mlir::Op<AllocOp, mlir::OpTrait::ZeroOperands,
mlir::OpTrait::OneResult> {
public:
static llvm::StringRef getOperationName() { return "toy.alloc"; }
/// Interface to mlir::Builder::create<AllocOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// `toy.alloc` does not have any argument and returns a toy array.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Type retType);
/// Inherit constructor.
using Op::Op;
};
/// FIXME: should be in std?
class TypeCastOp : public mlir::Op<TypeCastOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
static llvm::StringRef getOperationName() { return "toy.cast"; }
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, mlir::Type destTy);
// Register our patterns for rewrite by the Canonicalization framework.
static void
getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
mlir::MLIRContext *context);
/// Inherit constructor.
using Op::Op;
};
/// Include the auto-generated header file containing the declarations of the
/// toy operations.
#define GET_OP_CLASSES
#include "toy/Ops.h.inc"
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_

View File

@ -31,7 +31,7 @@ namespace toy {
/// Structure definition a location in a file.
struct Location {
std::shared_ptr<std::string> file; ///< filename
std::shared_ptr<std::string> file; ///< filename.
int line; ///< line number.
int col; ///< column number.
};

View File

@ -1,45 +0,0 @@
//===- Lowering.h - Lexer for the Toy language ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file exposes the interface to the lowering for Toy. It is divided in
// two parts: an *early lowering* that emits operations in the `Linalg`
// dialects for a subset of the Toy IR, and a *late lowering* that materializes
// buffers and converts all operations and type to the LLVM dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_
#define MLIR_EXAMPLES_TOY_LOWERING_H_
#include <memory>
namespace mlir {
class Pass;
class DialectConversion;
} // namespace mlir
namespace toy {
/// Create a pass for lowering operations in the `Linalg` dialects, for a subset
/// of the Toy IR (matmul).
std::unique_ptr<mlir::Pass> createEarlyLoweringPass();
/// Create a pass for the late lowering toward LLVM dialect.
std::unique_ptr<mlir::Pass> createLateLoweringPass();
} // namespace toy
#endif // MLIR_EXAMPLES_TOY_LOWERING_H_

View File

@ -0,0 +1,272 @@
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the operations of the Toy dialect.
//
//===----------------------------------------------------------------------===//
#ifdef TOY_OPS
#else
#define TOY_OPS
#ifdef MLIR_CALLINTERFACES
#else
include "mlir/Analysis/CallInterfaces.td"
#endif // MLIR_CALLINTERFACES
#ifdef SHAPE_INFERENCE_INTERFACE
#else
include "toy/ShapeInferenceInterface.td"
#endif // SHAPE_INFERENCE_INTERFACE
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "toy";
}
// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
// * The parent dialect of the operation.
// * The mnemonic for the operation, or the name without the dialect prefix.
// * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Toy_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
// We define a toy operation by inherting from our base 'Toy_Op' class above.
// Here we provide the mnemonic and a list of traits for the operation. The
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// Provide a summary and description for this operation. This can be used to
// auto-generate documenatation of the operations within our dialect.
let summary = "constant";
let description = [{
Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:
```mlir
%0 = "toy.constant"()
{ value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
: () -> tensor<2x3xf64>
```
}];
// The constant operation takes an attribute as the only input.
let arguments = (ins F64ElementsAttr:$value);
// The constant operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Add custom build methods for the constant operation. These method populates
// the `state` that MLIR uses to create operations, i.e. these are used when
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
];
// Invoke a static verify method to verify this constant operation.
let verifier = [{ return ::verify(*this); }];
}
def AddOp : Toy_Op<"add",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
The shapes of the tensor operands are expected to match.
}];
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);
// Allow building an AddOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
}
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
}];
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to
be specialized for the shape of its arguments. The callee name is attached
as a symbol reference via an attribute. The arguments list must match the
arguments expected by the callee. For example:
```mlir
%4 = "toy.generic_call"(%1, %3) {callee = @my_func}
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
```
This is only valid if a function named "my_func" exists and takes two
arguments.
}];
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
// Add custom build methods for the generic call operation.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
];
}
def MulOp : Toy_Op<"mul",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation";
let description = [{
The "mul" operation performs element-wise multiplication between two
tensors. The shapes of the tensor operands are expected to match.
}];
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
let results = (outs F64Tensor);
// Allow building a MulOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
}
def PrintOp : Toy_Op<"print"> {
let summary = "print operation";
let description = [{
The "print" builtin operation prints a given input tensor, and produces
no results.
}];
// The print operation takes an input tensor to print.
// We also allow a F64MemRef to enable interop during partial lowering.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}
def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
let summary = "tensor reshape operation";
let description = [{
Reshape operation is transforming its input tensor into a new tensor with
the same number of elements but different shapes. For example:
```mlir
%0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
```
}];
let arguments = (ins F64Tensor:$input);
let hasCanonicalizer = 1;
// We expect that the reshape operation returns a statically shaped tensor.
let results = (outs StaticShapeTensorOf<[F64]>);
}
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
let summary = "return operation";
let description = [{
The "return" operation represents a return operation within a function.
The operation takes an optional tensor operand and produces no results.
The operand type must match the signature of the function that contains
the operation. For example:
```mlir
func @foo() -> tensor<2xf64> {
...
toy.return %0 : tensor<2xf64>
}
```
}];
// The return operation takes an optional input operand to return. This
// value must match the return type of the enclosing function.
let arguments = (ins Variadic<F64Tensor>:$input);
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
let extraClassDeclaration = [{
bool hasOperand() { return getNumOperands() != 0; }
}];
// Invoke a static verify method to verify this return operation.
let verifier = [{ return ::verify(*this); }];
}
def TransposeOp : Toy_Op<"transpose",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor);
let hasCanonicalizer = 1;
// Allow building a TransposeOp with from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
}
#endif // TOY_OPS

View File

@ -26,10 +26,16 @@
namespace mlir {
class Pass;
} // namespace mlir
namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass();
} // namespace toy
std::unique_ptr<Pass> createDeadFunctionEliminationPass();
std::unique_ptr<Pass> createShapeInferencePass();
/// Create a pass for lowering to operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul).
std::unique_ptr<mlir::Pass> createLowerToAffinePass();
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_PASSES_H

View File

@ -0,0 +1,37 @@
//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains the declarations of the shape inference interfaces defined
// in ShapeInferenceInterface.td.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace toy {
/// Include the auto-generated declarations.
#include "toy/ShapeInferenceOpInterfaces.h.inc"
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_

View File

@ -0,0 +1,38 @@
//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the operations of the Shape Inference Op Interface.
//
//===----------------------------------------------------------------------===//
#ifdef SHAPE_INFERENCE_INTERFACE
#else
#define SHAPE_INFERENCE_INTERFACE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
];
}
#endif // SHAPE_INFERENCE_INTERFACE

View File

@ -0,0 +1,68 @@
//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing dead function
// elimination. This is required as a post-processing step after function
// inlining.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "toy/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
namespace {
/// This is a simple function DCE pass that deletes all non-main functions after
/// inlining.
/// TODO(riverriddle) This is only necessary because MLIR currently does not
/// have generic DCE support for functions.
class DeadFunctionEliminationPass
: public mlir::ModulePass<DeadFunctionEliminationPass> {
public:
void runOnModule() override {
mlir::ModuleOp module = getModule();
mlir::SymbolTable moduleSymTable(module);
// Eliminate non-main functions.
auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
for (mlir::FuncOp func :
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
if (func != mainFn)
func.erase();
}
}
};
} // end anonymous namespace
/// Create a pass that eliminates inlined functions in toy.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
return std::make_unique<DeadFunctionEliminationPass>();
}

View File

@ -0,0 +1,256 @@
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::toy;
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//
/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
/// All operations within toy can be inlined.
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator(toy.return) by replacing it with a new
/// operation as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value *> valuesToRepl) const final {
// Only "toy.return" needs to be handled here.
auto returnOp = cast<ReturnOp>(op);
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
/// Attempts to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
Type resultType,
Location conversionLoc) const final {
return builder.create<CastOp>(conversionLoc, resultType, input);
}
};
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
>();
addInterfaces<ToyInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = builder->getTensorType({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = op.value().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op.emitOpError(
"return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}
// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return op.emitOpError(
"return type shape mismatches its attribute at dimension ")
<< dim << ": " << attrType.getShape()[dim]
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("callee");
}
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() {
auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank)
return;
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else if (lhsRank == 2) {
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
} else {
return;
}
getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
}
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
auto function = cast<FuncOp>(op.getParentOp());
/// ReturnOps can only have a single optional operand.
if (op.getNumOperands() > 1)
return op.emitOpError() << "expects at most 1 return operand";
// The operand number and types must match the function signature.
const auto &results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError()
<< "does not return the same number of values ("
<< op.getNumOperands() << ") as the enclosing function ("
<< results.size() << ")";
// If the operation does not have an input, we are done.
if (!op.hasOperand())
return mlir::success();
auto inputType = *op.operand_type_begin();
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
return mlir::success();
return op.emitError() << "type of return operand ("
<< *op.operand_type_begin()
<< ") doesn't match function result type ("
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
state.addTypes(builder->getTensorType(builder->getF64Type()));
state.addOperands(value);
}
void TransposeOp::inferShapes() {
SmallVector<int64_t, 2> dims;
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"

View File

@ -1,148 +0,0 @@
//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements early lowering of Toy IR to Linalg Dialect: we only
// lower the computationally intensive part of the program (matmul...) to a
// dialect specialized for optimizations.
//
// This is intended to showcase how multiple dialects can cohabit in the same
// function. After this lowering, you would still have toy.print in the IR for
// example.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "linalg1/Dialect.h"
#include "linalg1/Intrinsics.h"
#include "linalg1/ViewOp.h"
#include "linalg3/TensorOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include <algorithm>
using namespace mlir;
namespace {
/// Utility function for type casting: this is making the type checker happy,
/// while delaying the actual work involved to convert the type. Most of the
/// time both side of the cast (producer and consumer) will be lowered to a
/// dialect like LLVM and end up with the same LLVM representation, at which
/// point this becomes a no-op and is eliminated.
Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) {
if (val->getType() == destTy)
return val;
return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
.getResult();
}
/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
/// lowered to a memref during buffer allocation, at which point the type cast
/// becomes useless.
Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) {
if (val->getType().isa<MemRefType>())
return val;
auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
if (!toyArrayTy)
return val;
return typeCast(builder, val, toyArrayTy.toMemref());
}
/// Lower toy.mul to Linalg `matmul`.
///
/// This class inherit from `ConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class MulOpConversion : public ConversionPattern {
public:
explicit MulOpConversion(MLIRContext *context)
: ConversionPattern(toy::MulOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
using namespace edsc;
using intrinsics::constant_index;
using linalg::intrinsics::range;
using linalg::intrinsics::view;
toy::MulOp mul = cast<toy::MulOp>(op);
auto loc = mul.getLoc();
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
.getResult());
Value *lhs = memRefTypeCast(rewriter, operands[0]);
auto memrefLHSTy = lhs->getType().cast<MemRefType>();
Value *rhs = memRefTypeCast(rewriter, operands[1]);
auto memrefRHSTy = rhs->getType().cast<MemRefType>();
mlir::edsc::ScopedContext scope(rewriter, loc);
edsc::ValueHandle r0 =
range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)),
constant_index(1));
edsc::ValueHandle r1 =
range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)),
constant_index(1));
edsc::ValueHandle r2 =
range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)),
constant_index(1));
auto lhsView = view(lhs, {r0, r1});
auto rhsView = view(rhs, {r1, r2});
auto resultView = view(result, {r0, r2});
rewriter.create<linalg::MatmulOp>(loc, lhsView, rhsView, resultView);
rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())});
return matchSuccess();
}
};
/// This is lowering to Linalg the parts that are computationally intensive
/// (like matmul for example...) while keeping the rest of the code in the Toy
/// dialect.
struct EarlyLoweringPass : public FunctionPass<EarlyLoweringPass> {
void runOnFunction() override {
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
OwningRewritePatternList patterns;
patterns.insert<MulOpConversion>(&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n");
signalPassFailure();
}
}
};
} // end anonymous namespace
namespace toy {
std::unique_ptr<mlir::Pass> createEarlyLoweringPass() {
return std::make_unique<EarlyLoweringPass>();
}
} // namespace toy

View File

@ -1,470 +0,0 @@
//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements late lowering of IR mixing Toy and Linalg to LLVM.
// It involves intemerdiate steps:
// -
// - a mix of affine and standard dialect.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "linalg1/Dialect.h"
#include "linalg1/Intrinsics.h"
#include "linalg1/ViewOp.h"
#include "linalg3/ConvertToLLVMDialect.h"
#include "linalg3/TensorOps.h"
#include "linalg3/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include <algorithm>
using namespace mlir;
namespace {
/// Utility function for type casting: this is making the type checker happy,
/// while delaying the actual work involved to convert the type. Most of the
/// time both side of the cast (producer and consumer) will be lowered to a
/// dialect like LLVM and end up with the same LLVM representation, at which
/// point this becomes a no-op and is eliminated.
Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) {
if (val->getType() == destTy)
return val;
return builder.create<toy::TypeCastOp>(val->getLoc(), val, destTy)
.getResult();
}
/// Create a type cast to turn a toy.array into a memref. The Toy Array will be
/// lowered to a memref during buffer allocation, at which point the type cast
/// becomes useless.
Value *memRefTypeCast(PatternRewriter &builder, Value *val) {
if (val->getType().isa<MemRefType>())
return val;
auto toyArrayTy = val->getType().dyn_cast<toy::ToyArrayType>();
if (!toyArrayTy)
return val;
return typeCast(builder, val, toyArrayTy.toMemref());
}
/// Lower a toy.add to an affine loop nest.
///
/// This class inherit from `ConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class AddOpConversion : public ConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
: ConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
/// is setup with a new function, the `operands` array has been populated with
/// the rewritten operands for `op` in the new function.
/// The results created by the new IR with the builder are returned, and their
/// number must match the number of result of `op`.
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto add = cast<toy::AddOp>(op);
auto loc = add.getLoc();
// Create a `toy.alloc` operation to allocate the output buffer for this op.
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, add.getResult()->getType())
.getResult());
Value *lhs = memRefTypeCast(rewriter, operands[0]);
Value *rhs = memRefTypeCast(rewriter, operands[1]);
using namespace edsc;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
IndexHandle i, j, M(vRes.ub(0));
if (vRes.rank() == 1) {
LoopNestBuilder({&i}, {zero}, {M},
{1})([&] { iRes(i) = iLHS(i) + iRHS(i); });
} else {
assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now");
IndexHandle N(vRes.ub(1));
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N},
{1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); });
}
// Return the newly allocated buffer, with a type.cast to preserve the
// consumers.
rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())});
return matchSuccess();
}
};
/// Lowers `toy.print` to a loop nest calling `printf` on every individual
/// elements of the array.
class PrintOpConversion : public ConversionPattern {
public:
explicit PrintOpConversion(MLIRContext *context)
: ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType<ModuleOp>());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
// We will operate on a MemRef abstraction, we use a type.cast to get one
// if our operand is still a Toy array.
Value *operand = memRefTypeCast(rewriter, operands[0]);
Type retTy = printfFunc.getType().getFunctionResultType();
// Create our loop nest now
using namespace edsc;
using extractvalue = intrinsics::ValueBuilder<LLVM::ExtractValueOp>;
using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
MemRefView vOp(operand);
IndexedValue iOp(operand);
IndexHandle i, j, M(vOp.ub(0));
auto *dialect = op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto i8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
if (vOp.rank() == 1) {
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
llvmCall(retTy,
rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
iOp(i)});
});
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
// clang-format on
} else {
IndexHandle N(vOp.ub(1));
// clang-format off
LoopBuilder(&i, zero, M, 1)([&]{
LoopBuilder(&j, zero, N, 1)([&]{
llvmCall(
retTy,
rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)),
iOp(i, j)});
});
llvmCall(
retTy,
rewriter.getSymbolRefAttr(printfFunc),
{extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))});
});
// clang-format on
}
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
private:
// Turn a string into a toy.alloc (malloc/free abstraction) and a sequence
// of stores into the buffer, and return a MemRef into the buffer.
Value *getConstantCharBuffer(PatternRewriter &builder, Location loc,
StringRef data) const {
auto retTy =
builder.getMemRefType(data.size() + 1, builder.getIntegerType(8));
Value *result = builder.create<toy::AllocOp>(loc, retTy).getResult();
using namespace edsc;
using intrinsics::constant_index;
using intrinsics::constant_int;
ScopedContext scope(builder, loc);
MemRefView vOp(result);
IndexedValue iOp(result);
for (uint64_t i = 0; i < data.size(); ++i) {
iOp(constant_index(i)) = constant_int(data[i], 8);
}
iOp(constant_index(data.size())) = constant_int(0, 8);
return result;
}
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
LLVM::LLVMFuncOp getPrintf(ModuleOp module) const {
auto printfFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("printf");
if (printfFunc)
return printfFunc;
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
OpBuilder builder(module.getBodyRegion());
auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect);
auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo();
auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
/*isVarArg=*/true);
return builder.create<LLVM::LLVMFuncOp>(builder.getUnknownLoc(), "printf",
printfTy,
ArrayRef<NamedAttribute>());
}
};
/// Lowers constant to a sequence of store in a buffer.
class ConstantOpConversion : public ConversionPattern {
public:
explicit ConstantOpConversion(MLIRContext *context)
: ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
auto loc = cstOp.getLoc();
auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
auto shape = retTy.getShape();
Value *result = memRefTypeCast(
rewriter, rewriter.create<toy::AllocOp>(loc, retTy).getResult());
auto cstValue = cstOp.getValue();
auto f64Ty = rewriter.getF64Type();
using namespace edsc;
using intrinsics::constant_float;
using intrinsics::constant_index;
ScopedContext scope(rewriter, loc);
MemRefView vOp(result);
IndexedValue iOp(result);
for (uint64_t i = 0, ie = shape[0]; i < ie; ++i) {
if (shape.size() == 1) {
auto value = cstValue.getValue<APFloat>(ArrayRef<uint64_t>{i});
iOp(constant_index(i)) = constant_float(value, f64Ty);
continue;
}
for (uint64_t j = 0, je = shape[1]; j < je; ++j) {
auto value = cstValue.getValue<APFloat>(ArrayRef<uint64_t>{i, j});
iOp(constant_index(i), constant_index(j)) =
constant_float(value, f64Ty);
}
}
rewriter.replaceOp(op, result);
return matchSuccess();
}
};
/// Lower transpose operation to an affine loop nest.
class TransposeOpConversion : public ConversionPattern {
public:
explicit TransposeOpConversion(MLIRContext *context)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto transpose = cast<toy::TransposeOp>(op);
auto loc = transpose.getLoc();
Value *result = memRefTypeCast(
rewriter,
rewriter.create<toy::AllocOp>(loc, transpose.getResult()->getType())
.getResult());
Value *operand = memRefTypeCast(rewriter, operands[0]);
using namespace edsc;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
MemRefView vRes(result), vOperand(operand);
IndexedValue iRes(result), iOperand(operand);
IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1));
// clang-format off
LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{
iRes(i, j) = iOperand(j, i);
});
// clang-format on
rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())});
return matchSuccess();
}
};
// Lower toy.return to standard return operation.
class ReturnOpConversion : public ConversionPattern {
public:
explicit ReturnOpConversion(MLIRContext *context)
: ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
// Argument is optional, handle both cases.
if (op->getNumOperands())
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands[0]);
else
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
}
};
/// This is the main class registering our individual converter classes with
/// the DialectConversion framework in MLIR.
class ToyTypeConverter : public TypeConverter {
protected:
/// Convert a Toy type, this gets called for block and region arguments, and
/// attributes.
Type convertType(Type t) override {
if (auto array = t.dyn_cast<toy::ToyArrayType>())
return array.toMemref();
return t;
}
/// Materialize a conversion to allow for partial lowering of types.
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
ArrayRef<Value *> inputs,
Location loc) override {
assert(inputs.size() == 1 && "expected only one input value");
return rewriter.create<toy::TypeCastOp>(loc, inputs[0], resultType);
}
};
/// This is lowering to Linalg the parts that can be (matmul and add on arrays)
/// and is targeting LLVM otherwise.
struct LateLoweringPass : public ModulePass<LateLoweringPass> {
void runOnModule() override {
ToyTypeConverter typeConverter;
OwningRewritePatternList toyPatterns;
toyPatterns.insert<AddOpConversion, PrintOpConversion, ConstantOpConversion,
TransposeOpConversion, ReturnOpConversion>(
&getContext());
mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(),
typeConverter);
// Perform Toy specific lowering.
ConversionTarget target(getContext());
target.addLegalDialect<AffineOpsDialect, linalg::LinalgDialect,
LLVM::LLVMDialect, StandardOpsDialect>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType());
});
target.addLegalOp<toy::AllocOp, toy::TypeCastOp>();
if (failed(applyPartialConversion(getModule(), target, toyPatterns,
&typeConverter))) {
emitError(UnknownLoc::get(getModule().getContext()),
"error lowering Toy\n");
signalPassFailure();
}
// At this point the IR is almost using only standard and affine dialects.
// A few things remain before we emit LLVM IR. First to reuse as much of
// MLIR as possible we will try to lower everything to the standard and/or
// affine dialect: they already include conversion to the LLVM dialect.
// First patch calls type to return memref instead of ToyArray
for (auto function : getModule().getOps<FuncOp>()) {
function.walk([&](Operation *op) {
auto callOp = dyn_cast<CallOp>(op);
if (!callOp)
return;
if (!callOp.getNumResults())
return;
auto retToyTy =
callOp.getResult(0)->getType().dyn_cast<toy::ToyArrayType>();
if (!retToyTy)
return;
callOp.getResult(0)->setType(retToyTy.toMemref());
});
}
for (auto function : getModule().getOps<FuncOp>()) {
function.walk([&](Operation *op) {
// Turns toy.alloc into sequence of alloc/dealloc (later malloc/free).
if (auto allocOp = dyn_cast<toy::AllocOp>(op)) {
auto result = allocTensor(allocOp);
allocOp.replaceAllUsesWith(result);
allocOp.erase();
return;
}
// Eliminate all type.cast before lowering to LLVM.
if (auto typeCastOp = dyn_cast<toy::TypeCastOp>(op)) {
typeCastOp.replaceAllUsesWith(typeCastOp.getOperand());
typeCastOp.erase();
return;
}
});
}
// Lower Linalg to affine
for (auto function : getModule().getOps<FuncOp>())
linalg::lowerToLoops(function);
getModule().dump();
// Finally convert to LLVM Dialect
linalg::convertLinalg3ToLLVM(getModule());
}
/// Allocate buffers (malloc/free) for Toy operations. This can't be done as
/// part of dialect conversion framework since we need to insert `dealloc`
/// operations just before the return, but the conversion framework is
/// operating in a brand new function: we don't have the return to hook the
/// dealloc operations.
Value *allocTensor(toy::AllocOp alloc) {
OpBuilder builder(alloc);
auto retTy = alloc.getResult()->getType();
auto memRefTy = retTy.dyn_cast<MemRefType>();
if (!memRefTy)
memRefTy = retTy.cast<toy::ToyArrayType>().toMemref();
if (!memRefTy) {
alloc.emitOpError("is expected to allocate a Toy array or a MemRef");
llvm_unreachable("fatal error");
}
auto loc = alloc.getLoc();
Value *result = builder.create<AllocOp>(loc, memRefTy).getResult();
// Insert a `dealloc` operation right before the `return` operations, unless
// it is returned itself in which case the caller is responsible for it.
alloc.getParentRegion()->walk([&](Operation *op) {
auto returnOp = dyn_cast<ReturnOp>(op);
if (!returnOp)
return;
if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc)
return;
builder.setInsertionPoint(returnOp);
builder.create<DeallocOp>(alloc.getLoc(), result);
});
return result;
}
};
} // end anonymous namespace
namespace toy {
std::unique_ptr<mlir::Pass> createLateLoweringPass() {
return std::make_unique<LateLoweringPass>();
}
} // namespace toy

View File

@ -0,0 +1,318 @@
//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return MemRefType::get(type.getShape(), type.getElementType());
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value *insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter) {
auto alloc = rewriter.create<AllocOp>(loc, type);
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc.getOperation()->getBlock();
alloc.getOperation()->moveBefore(&parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
return alloc;
}
/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input a rewriter, an array of memRefOperands corresponding
/// to the operands of the input operation, and the set of loop induction
/// variables for the iteration. It returns a value to store at the current
/// index of the iteration.
using LoopIterationFn = function_ref<Value *(PatternRewriter &rewriter,
ArrayRef<Value *> memRefOperands,
ArrayRef<Value *> loopIvs)>;
static void lowerOpToLoops(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// Create an empty affine loop for each of the dimensions within the shape.
SmallVector<Value *, 4> loopIvs;
for (auto dim : tensorType.getShape()) {
auto loop = rewriter.create<AffineForOp>(loc, /*lb=*/0, dim, /*step=*/1);
loop.getBody()->clear();
loopIvs.push_back(loop.getInductionVar());
// Terminate the loop body and update the rewriter insertion point to the
// beginning of the loop.
rewriter.setInsertionPointToStart(loop.getBody());
rewriter.create<AffineTerminatorOp>(loc);
rewriter.setInsertionPointToStart(loop.getBody());
}
// Generate a call to the processing function with the rewriter, the memref
// operands, and the loop induction variables. This function will return the
// value to store at the current index.
Value *valueToStore = processIteration(rewriter, operands, loopIvs);
rewriter.create<AffineStoreOp>(loc, valueToStore, alloc,
llvm::makeArrayRef(loopIvs));
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
}
namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
lowerOpToLoops(
op, operands, rewriter,
[loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
ArrayRef<Value *> loopIvs) {
// Generate an adaptor for the remapped operands of the BinaryOp. This
// allows for using the nice named accessors that are generated by the
// ODS.
typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
auto loadedLhs =
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
auto loadedRhs =
rewriter.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
// Create the binary operation performed on the loaded values.
return rewriter.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
return matchSuccess();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<TensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// We will be generating constant indices up-to the largest dimension.
// Create these constants up-front to avoid large amounts of redundant
// operations.
auto valueShape = memRefType.getShape();
SmallVector<Value *, 8> constantIndices;
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
// The constant operation represents a multi-dimensional constant, so we
// will need to generate a store for each of the elements. The following
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value *, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
if (dimension == valueShape.size()) {
rewriter.create<AffineStoreOp>(
loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
llvm::makeArrayRef(indices));
return;
}
// Otherwise, iterate over the current dimension and add the indices to
// the list.
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
}
};
// Start the element storing recursion from the first dimension.
storeElements(/*dimension=*/0);
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return matchFailure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//
struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
lowerOpToLoops(
op, operands, rewriter,
[loc](PatternRewriter &rewriter, ArrayRef<Value *> memRefOperands,
ArrayRef<Value *> loopIvs) {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS.
toy::TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands);
Value *input = tranposeAdaptor.input();
// Transpose the elements by generating a load from the reverse
// indices.
SmallVector<Value *, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//
/// This is a partial lowering to affine loops of the toy operations that are
/// computationally intensive (like matmul for example...) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass : public FunctionPass<ToyToAffineLoweringPass> {
void runOnFunction() final;
};
} // end anonymous namespace.
void ToyToAffineLoweringPass::runOnFunction() {
auto function = getFunction();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as `legal`.
target.addIllegalDialect<toy::ToyDialect>();
target.addLegalOp<toy::PrintOp>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
OwningRewritePatternList patterns;
patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul).
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
return std::make_unique<ToyToAffineLoweringPass>();
}

View File

@ -25,30 +25,30 @@
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
using namespace mlir::toy;
using namespace toy;
using llvm::ArrayRef;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
using llvm::makeArrayRef;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using std::make_unique;
namespace {
@ -57,56 +57,43 @@ namespace {
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
///
/// At this point we take advantage of the "raw" MLIR APIs to create operations
/// that haven't been registered in any way with MLIR. These operations are
/// unknown to MLIR, custom passes could operate by string-matching the name of
/// these operations, but no other type checking or semantic is associated with
/// them natively by MLIR.
class MLIRGenImpl {
public:
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
/// Module operation.
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->push_back(func);
theModule.push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
// this won't do much, but it should at least check some structural
// properties.
if (failed(mlir::verify(*theModule))) {
emitError(mlir::UnknownLoc::get(&context), "module verification error");
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
// have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("module verification error");
return nullptr;
}
return std::move(theModule);
return theModule;
}
private:
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
/// the ownership of many internal structure of the IR and provide a level
/// of "uniquing" across multiple modules (types for instance).
mlir::MLIRContext &context;
/// A "module" matches a Toy source file: containing a list of functions.
mlir::ModuleOp theModule;
/// A "module" matches a source file: it contains a list of functions.
mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
/// convenience for emitting individual operations.
/// The builder is stateful, in particular it keeeps an "insertion point":
/// this is where the next operations will be introduced.
std::unique_ptr<mlir::OpBuilder> builder;
/// The builder is a helper class to create IR inside a function. The builder
/// is stateful, in particular it keeeps an "insertion point": this is where
/// the next operations will be introduced.
mlir::OpBuilder builder;
/// The symbol table maps a variable name to a value in the current scope.
/// Entering a function creates a new scope, and the function arguments are
@ -116,37 +103,35 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
loc.line, loc.col, &context);
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}
/// Declare a variable in the current scope, return true if the variable
/// Declare a variable in the current scope, return success if the variable
/// wasn't declared yet.
bool declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var)) {
return false;
}
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var))
return mlir::failure();
symbolTable.insert(var, value);
return true;
return mlir::success();
}
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto func_type = builder.getFunctionType(arg_types, llvm::None);
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
function.setAttr("toy.generic", builder.getUnitAttr());
return function;
}
@ -165,29 +150,39 @@ private:
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
llvm::zip(protoArgs, entryBlock.getArguments())) {
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
if (failed(declare(std::get<0>(name_value)->getName(),
std::get<1>(name_value))))
return nullptr;
}
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
// Set the insertion point in the builder to the beginning of the function
// body, it will be used throughout the codegen to create operations in this
// function.
builder.setInsertionPointToStart(&entryBlock);
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody())) {
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
function.erase();
return nullptr;
}
// Implicitly return void if no return statement was emited.
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
ReturnOp returnOp;
if (!entryBlock.empty())
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
if (!returnOp) {
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
} else if (returnOp.hasOperand()) {
// Otherwise, if this return operation has an operand then add a result to
// the function.
function.setType(builder.getFunctionType(function.getType().getInputs(),
getType(VarType{})));
}
return function;
@ -206,11 +201,11 @@ private:
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
mlir::Value *L = mlirGen(*binop.getLHS());
if (!L)
mlir::Value *lhs = mlirGen(*binop.getLHS());
if (!lhs)
return nullptr;
mlir::Value *R = mlirGen(*binop.getRHS());
if (!R)
mlir::Value *rhs = mlirGen(*binop.getRHS());
if (!rhs)
return nullptr;
auto location = loc(binop.loc());
@ -218,123 +213,113 @@ private:
// support '+' and '*'.
switch (binop.getOp()) {
case '+':
return builder->create<AddOp>(location, L, R).getResult();
break;
return builder.create<AddOp>(location, lhs, rhs);
case '*':
return builder->create<MulOp>(location, L, R).getResult();
default:
emitError(loc(binop.loc()), "error: invalid binary operator '")
<< binop.getOp() << "'";
return nullptr;
return builder.create<MulOp>(location, lhs, rhs);
}
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
return nullptr;
}
// This is a reference to a variable in an expression. The variable is
// expected to have been declared and so should have a value in the symbol
// table, otherwise emit an error and return nullptr.
/// This is a reference to a variable in an expression. The variable is
/// expected to have been declared and so should have a value in the symbol
/// table, otherwise emit an error and return nullptr.
mlir::Value *mlirGen(VariableExprAST &expr) {
if (symbolTable.count(expr.getName()))
return symbolTable.lookup(expr.getName());
if (auto *variable = symbolTable.lookup(expr.getName()))
return variable;
emitError(loc(expr.loc()), "error: unknown variable '")
<< expr.getName() << "'";
return nullptr;
}
// Emit a return operation, return true on success.
bool mlirGen(ReturnExprAST &ret) {
/// Emit a return operation. This will return failure if any generation fails.
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
auto location = loc(ret.loc());
// `return` takes an optional expression, we need to account for it here.
if (!ret.getExpr().hasValue()) {
builder->create<ReturnOp>(location);
return true;
// 'return' takes an optional expression, handle that case here.
mlir::Value *expr = nullptr;
if (ret.getExpr().hasValue()) {
if (!(expr = mlirGen(*ret.getExpr().getValue())))
return mlir::failure();
}
auto *expr = mlirGen(*ret.getExpr().getValue());
if (!expr)
return false;
builder->create<ReturnOp>(location, expr);
return true;
// Otherwise, this return operation has zero operands.
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
: ArrayRef<mlir::Value *>());
return mlir::success();
}
// Emit a literal/constant array. It will be emitted as a flattened array of
// data in an Attribute attached to a `toy.constant` operation.
// See documentation on [Attributes](LangRef.md#attributes) for more details.
// Here is an excerpt:
//
// Attributes are the mechanism for specifying constant data in MLIR in
// places where a variable is never allowed [...]. They consist of a name
// and a [concrete attribute value](#attribute-values). It is possible to
// attach attributes to operations, functions, and function arguments. The
// set of expected attributes, their structure, and their interpretation
// are all contextually dependent on what they are attached to.
//
// Example, the source level statement:
// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
// will be converted to:
// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
//
/// Emit a literal/constant array. It will be emitted as a flattened array of
/// data in an Attribute attached to a `toy.constant` operation.
/// See documentation on [Attributes](LangRef.md#attributes) for more details.
/// Here is an excerpt:
///
/// Attributes are the mechanism for specifying constant data in MLIR in
/// places where a variable is never allowed [...]. They consist of a name
/// and a concrete attribute value. The set of expected attributes, their
/// structure, and their interpretation are all contextually dependent on
/// what they are attached to.
///
/// Example, the source level statement:
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
/// will be converted to:
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
///
mlir::Value *mlirGen(LiteralExprAST &lit) {
auto location = loc(lit.loc());
// The attribute is a vector with an attribute per element (number) in the
// array, see `collectData()` below for more details.
std::vector<mlir::Attribute> data;
auto type = getType(lit.getDims());
// The attribute is a vector with a floating point value per element
// (number) in the array, see `collectData()` below for more details.
std::vector<double> data;
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
std::multiplies<int>()));
collectData(lit, data);
// FIXME: using a tensor type is a HACK here.
// Can we do differently without registering a dialect? Using a string blob?
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto dataType = builder->getTensorType(lit.getDims(), elementType);
// The type of this attribute is tensor of 64-bit floating-point with the
// shape of the literal.
mlir::Type elementType = builder.getF64Type();
auto dataType = builder.getTensorType(lit.getDims(), elementType);
// This is the actual attribute that actually hold the list of values for
// this array literal.
auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
.cast<mlir::DenseElementsAttr>();
// This is the actual attribute that holds the list of values for this
// tensor literal.
auto dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
// Build the MLIR op `toy.constant`, only boilerplate below.
return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
.getResult();
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
// method.
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
// Recursive helper function to accumulate the data that compose an array
// literal. It flattens the nested structure in the supplied vector. For
// example with this array:
// [[1, 2], [3, 4]]
// we will generate:
// [ 1, 2, 3, 4 ]
// Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
// Attributes are the way MLIR attaches constant to operations and functions.
void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
/// Recursive helper function to accumulate the data that compose an array
/// literal. It flattens the nested structure in the supplied vector. For
/// example with this array:
/// [[1, 2], [3, 4]]
/// we will generate:
/// [ 1, 2, 3, 4 ]
/// Individual numbers are represented as doubles.
/// Attributes are the way MLIR attaches constant to operations.
void collectData(ExprAST &expr, std::vector<double> &data) {
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
for (auto &value : lit->getValues())
collectData(*value, data);
return;
}
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(
elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
data.push_back(attr);
data.push_back(cast<NumberExprAST>(expr).getValue());
}
// Emit a call expression. It emits specific operations for the `transpose`
// builtin. Other identifiers are assumed to be user-defined functions.
/// Emit a call expression. It emits specific operations for the `transpose`
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());
std::string callee = call.getCallee();
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
mlir::Value *arg = mlirGen(*call.getArgs()[0]);
return builder->create<TransposeOp>(location, arg).getResult();
}
// Codegen the operands first
// Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
@ -342,34 +327,41 @@ private:
return nullptr;
operands.push_back(arg);
}
// Calls to user-defined function are mapped to a custom call that takes
// the callee name as an attribute.
return builder->create<GenericCallOp>(location, call.getCallee(), operands)
.getResult();
// Builting calls have their custom operation, meaning this is a
// straightforward emission.
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
return builder.create<TransposeOp>(location, operands[0]);
}
// Otherwise this is a call to a user-defined function. Calls to ser-defined
// functions are mapped to a custom call that takes the callee name as an
// attribute.
return builder.create<GenericCallOp>(location, callee, operands);
}
// Emit a call expression. It emits specific operations for two builtins:
// transpose(x) and print(x). Other identifiers are assumed to be user-defined
// functions. Return false on failure.
bool mlirGen(PrintExprAST &call) {
/// Emit a print expression. It emits specific operations for two builtins:
/// transpose(x) and print(x).
mlir::LogicalResult mlirGen(PrintExprAST &call) {
auto *arg = mlirGen(*call.getArg());
if (!arg)
return false;
auto location = loc(call.loc());
builder->create<PrintOp>(location, arg);
return true;
return mlir::failure();
builder.create<PrintOp>(loc(call.loc()), arg);
return mlir::success();
}
// Emit a constant for a single number (FIXME: semantic? broadcast?)
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
auto location = loc(num.loc());
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
loc(num.loc()));
return builder->create<ConstantOp>(location, attr).getResult();
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}
// Dispatch codegen for the right expression subclass using RTTI.
/// Dispatch codegen for the right expression subclass using RTTI.
mlir::Value *mlirGen(ExprAST &expr) {
switch (expr.getKind()) {
case toy::ExprAST::Expr_BinOp:
@ -383,84 +375,82 @@ private:
case toy::ExprAST::Expr_Num:
return mlirGen(cast<NumberExprAST>(expr));
default:
emitError(loc(expr.loc()),
"MLIR codegen encountered an unhandled expr kind '")
emitError(loc(expr.loc()))
<< "MLIR codegen encountered an unhandled expr kind '"
<< Twine(expr.getKind()) << "'";
return nullptr;
}
}
// Handle a variable declaration, we'll codegen the expression that forms the
// initializer and record the value in the symbol table before returning it.
// Future expressions will be able to reference this variable through symbol
// table lookup.
/// Handle a variable declaration, we'll codegen the expression that forms the
/// initializer and record the value in the symbol table before returning it.
/// Future expressions will be able to reference this variable through symbol
/// table lookup.
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
mlir::Value *value = nullptr;
auto location = loc(vardecl.loc());
if (auto init = vardecl.getInitVal()) {
value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder
->create<ReshapeOp>(
location, value,
getType(vardecl.getType()).cast<ToyArrayType>())
.getResult();
}
} else {
auto init = vardecl.getInitVal();
if (!init) {
emitError(loc(vardecl.loc()),
"missing initializer in variable declaration");
return nullptr;
}
// Register the value in the symbol table
declare(vardecl.getName(), value);
mlir::Value *value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder.create<ReshapeOp>(loc(vardecl.loc()),
getType(vardecl.getType()), value);
}
// Register the value in the symbol table.
if (failed(declare(vardecl.getName(), value)))
return nullptr;
return value;
}
/// Codegen a list of expression, return false if one of them hit an error.
bool mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
/// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested
// expressions.
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
if (!mlirGen(*vardecl))
return false;
return mlir::failure();
continue;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
if (!mlirGen(*ret))
return false;
return true;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
return mlirGen(*ret);
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
if (!mlirGen(*print))
return false;
if (mlir::failed(mlirGen(*print)))
return mlir::success();
continue;
}
// Generic expression dispatch codegen.
if (!mlirGen(*expr))
return false;
return mlir::failure();
}
return true;
return mlir::success();
}
/// Build a type from a list of shape dimensions. Types are `array` followed
/// by an optional dimension list, example: array<2, 2>
/// They are wrapped in a `toy` dialect (see next chapter) and get printed:
/// !toy.array<2, 2>
template <typename T> mlir::Type getType(T shape) {
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
return ToyArrayType::get(&context, shape64);
/// Build a tensor type from a list of shape dimensions.
mlir::Type getType(ArrayRef<int64_t> shape) {
// If the shape is empty, then this type is unranked.
if (shape.empty())
return builder.getTensorType(builder.getF64Type());
// Otherwise, we use the given shape.
return builder.getTensorType(shape, builder.getF64Type());
}
/// Build an MLIR type from a Toy AST variable type
/// (forward to the generic getType(T) above).
/// Build an MLIR type from a Toy AST variable type (forward to the generic
/// getType above).
mlir::Type getType(const VarType &type) { return getType(type.shape); }
};

View File

@ -1,4 +1,4 @@
//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,213 +15,55 @@
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing interprocedural
// This file implements a Function level pass performing interprocedural
// propagation of array shapes through function specialization.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "toy/ShapeInferenceInterface.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#define DEBUG_TYPE "toy-shape-inference"
#define DEBUG_TYPE "shape-inference"
using namespace mlir;
using namespace toy;
using llvm::MutableArrayRef;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;
/// Create mangled name for function specialization. We will simply append the
/// shape of the arguments to the function name. For example calling
///
/// "toy.generic_call"(%1, %3) {callee: "foo"}
/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array">
///
/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
/// have provide a function with a similar name. But we will claim this as a
/// feature: this allow the user to provide custom specialization!
static std::string mangle(StringRef funcName,
MutableArrayRef<mlir::OpOperand> operands) {
std::string mangledName;
mangledName.reserve(funcName.size() + operands.size() * 6);
mangledName = funcName;
for (auto &operand : operands) {
auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
mangledName += "_";
const char *sep = "";
for (auto dim : arrayTy.getShape()) {
mangledName += (sep + Twine(dim)).str();
sep = "x";
}
}
return mangledName;
}
/// Include the auto-generated definitions for the shape inference interfaces.
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
namespace {
/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
/// whole. MLIR also supports FunctionPass which are restricted to modify a
/// single function at a time. This pass couldn't be a function pass due the
/// nature of its interprocedural transformations.
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference.
///
/// The algorithm has two levels, first intra-procedurally:
/// Algorithm:
///
/// 1) Build a worklist containing all the operations that are returning
/// a generic Toy array: these are the operations that need shape
/// 1) Build a worklist containing all the operations that return a
/// dynamically shaped tensor: these are the operations that need shape
/// inference.
/// 2) Iterate on the worklist:
/// a) find an operation to process: the next ready operation in the
/// worklist has all of its arguments non-generic,
/// b) if no operation is found, break out of the loop,
/// c) remove the operation from the worklist,
/// d) infer the shape of its output from the arguments type.
/// 3) If the worklist is empty, the algorithm succeeded and we infer the
/// return type for the function from the return operation.
/// d) infer the shape of its output from the argument types.
/// 3) If the worklist is empty, the algorithm succeeded.
///
/// There is a twist though: when a call to a generic function is encountered,
/// shape inference requires the return type of the callee to be inferred first.
/// At this point we need to run specialize the callee by cloning it. Here is
/// the inter-procedural flow:
///
/// 1) Keep a worklist of function to process. Start with function "main".
/// 2) While the worklist isn't empty:
/// a) Take the last inserted function in the worklist.
/// b) Run the intra-procedural shape inference on this function.
/// c) If the intra-procedural shape inference can't complete, it returns
/// a FuncOp that needs to be inferred first. In this case, queue this
/// new function and continue. Otherwise the inference succeeded and we
/// can pop from the queue.
///
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
// One entry in the inter-procedural worklist. It keeps track of the
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
mlir::FuncOp function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
void runOnModule() override {
auto module = getModule();
mlir::ModuleManager moduleManager(module);
auto main = moduleManager.lookupSymbol<mlir::FuncOp>("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"shape inference failed: can't find a main function\n");
signalPassFailure();
return;
}
/// Inter-procedural loop, initialize with `main` and iterate till
/// successfully infer the full reachable call-graph from main.
SmallVector<FunctionToSpecialize, 8> worklist;
worklist.push_back({main, "", {}});
while (!worklist.empty()) {
if (failed(specialize(worklist, moduleManager)))
return;
}
// Delete any generic function left
// FIXME: we may want this as a separate pass.
for (mlir::FuncOp function :
llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
function.erase();
}
}
}
/// Run inference on a function. If a mangledName is provided, we need to
/// specialize the function: to this end clone it first.
mlir::LogicalResult
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
mlir::ModuleManager &moduleManager) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
mlir::FuncOp f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
// and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) {
if (moduleManager.lookupSymbol<mlir::FuncOp>(
functionToSpecialize.mangledName)) {
funcWorklist.pop_back();
// FuncOp already specialized, move on.
return mlir::success();
}
// Create a new function with a generic array return type, it will be
// updated when the inference for the function body completes.
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
auto newFunction =
mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
type, f.getDialectAttrs());
moduleManager.insert(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
f.dump();
llvm::dbgs() << "====== Into : \n";
newFunction.dump();
});
f = newFunction;
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
assert(succeeded(verify(f)));
}
LLVM_DEBUG(llvm::dbgs()
<< "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
signalPassFailure();
return emitError(mlir::UnknownLoc::get(&getContext()),
"Toy dialect is not registered");
}
void runOnFunction() override {
auto f = getFunction();
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
// these are operations that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
opWorklist.insert(op);
}
if (returnsDynamicShape(op))
opWorklist.insert(op);
});
// Iterate on the operations in the worklist until all operations have been
@ -229,152 +71,43 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
return !ty.cast<ToyArrayType>().isGeneric();
});
});
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
if (nextop == opWorklist.end())
break; // failure: no operations can be inferred.
break;
mlir::Operation *op = *nextop;
Operation *op = *nextop;
opWorklist.erase(op);
// Ask the operation to infer its output shapes.
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
// The add operation is trivial: propagate the input type as is.
if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
op->getResult(0)->setType(op->getOperand(0)->getType());
continue;
}
// Transpose is easy: just invert the dimensions.
if (op->getName().getStringRef() == "toy.transpose") {
SmallVector<int64_t, 2> dims;
auto arrayTy = op->getOperand(0)->getType().cast<ToyArrayType>();
dims.insert(dims.end(), arrayTy.getShape().begin(),
arrayTy.getShape().end());
if (dims.size() == 2)
std::swap(dims[0], dims[1]);
op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
continue;
}
// Multiplication is a bit trickier, handle rank 1 as dot product and rank
// 2 as matrix multiplications.
// We need to be careful about rank mismatch here: the verifier could
// catch it but shape inference earlier in the pass could generate an
// invalid IR (from an invalid Toy input of course) and we wouldn't want
// to crash here.
if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
auto lhsRank = lhs.getShape().size();
auto rhsRank = rhs.getShape().size();
if (lhsRank != rhsRank) {
return op->emitError("shape mismatch: LHS and RHS must have the same "
"rank for multiplication, got ")
<< lhsRank << " vs " << lhsRank;
}
SmallVector<int64_t, 2> dims;
if (lhsRank == 1) {
// dot product, result shape is <1>
dims.push_back(1);
} else {
if (lhsRank != 2) {
return op->emitError("shape mismatch: expect rank 1 or 2 for mul "
"operands, got ")
<< lhsRank;
}
dims.push_back(lhs.getShape()[0]);
dims.push_back(rhs.getShape()[1]);
}
op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
continue;
}
// Process calls: lookup the callee after mangling the name with the
// argument shapes. If the callee does not exist, we stop the inference
// for this function, queue the callee in the inter-procedural work list,
// and return. The current function stays in the work list and will
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
auto callee = moduleManager.lookupSymbol<mlir::FuncOp>(calleeName);
if (!callee) {
signalPassFailure();
return f.emitError("shape inference failed, call to unknown '")
<< calleeName << "'";
}
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
auto mangledCallee =
moduleManager.lookupSymbol<mlir::FuncOp>(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
funcWorklist.push_back({callee, std::move(mangledName),
llvm::to_vector<4>(op->getOperandTypes())});
return mlir::success();
}
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands(op->getOperands());
mlir::OpBuilder builder(op);
auto newCall =
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
if (newCall.getNumResults()) {
op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
op->erase();
continue;
}
if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
shapeOp.inferShapes();
} else {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
}
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
signalPassFailure();
auto diag = f.emitError("shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
diag << " - " << *ope << "\n";
return diag;
}
}
// Finally, update the return type of the function based on the argument to
// the return operation.
for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
f.setType(newType);
assert(succeeded(verify(f)));
break;
}
return mlir::success();
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
});
}
};
} // end anonymous namespace
namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass() {
/// Create a Shape Inference pass.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
return std::make_unique<ShapeInferencePass>();
}
} // namespace toy

View File

@ -15,24 +15,30 @@
// limitations under the License.
// =============================================================================
//
// This file implements a simple combiner for optimizing pattern in the Toy
// dialect.
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "toy/Dialect.h"
#include <numeric>
namespace toy {
using namespace mlir;
using namespace toy;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold transpose(transpose(x) -> transpose(x)
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return mlir::impl::foldCastOp(*this);
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
@ -40,9 +46,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
/// This method attempts to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. The pattern is
/// expected to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
@ -55,132 +61,23 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
if (!transposeInputOp)
return matchFailure();
// Use the rewriter to perform the replacement
// Use the rewriter to perform the replacement.
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
}
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
// If the input is defined by another constant, bingo!
if (!constantOp)
return matchFailure();
auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
// auto oldType = constantOp.getType();
auto newType = rewriter.getTensorType(
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr = valueAttr.reshape(newType);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
// Broadcast
auto dataSize = std::accumulate(reshapeType.getShape().begin(),
reshapeType.getShape().end(), 1,
std::multiplies<int>());
std::vector<mlir::Attribute> data(dataSize, valueAttr);
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
}
return matchSuccess();
}
};
/// Fold reshape(reshape(x)) -> reshape(x)
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current reshape.
mlir::Value *reshapeInput = op.getOperand();
// If the input is defined by another reshape, bingo!
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
// Use the rewriter to perform the replacement
rewriter.replaceOp(op, {reshapeInput});
return matchSuccess();
}
};
/// Fold reshape(x)) -> x, when input type matches output type
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
if (op.getOperand()->getType() != op.getType())
return matchFailure();
rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};
} // end anonymous namespace.
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context);
}
// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
SimplifyNullReshape>(context);
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
FoldConstantReshapeOptPattern>(context);
}
namespace {
/// Fold type.cast(x) -> x, when input type matches output type
struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
mlir::PatternMatchResult
matchAndRewrite(TypeCastOp typeCast,
mlir::PatternRewriter &rewriter) const override {
auto resTy = typeCast.getType();
auto *candidateOp = typeCast.getOperation();
while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
if (resTy == candidateOp->getOperand(0)->getType()) {
rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
return matchSuccess();
}
candidateOp = candidateOp->getOperand(0)->getDefiningOp();
}
return matchFailure();
}
};
} // end anonymous namespace.
void TypeCastOp::getCanonicalizationPatterns(
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<SimplifyIdentityTypeCast>(context);
}
} // namespace toy

View File

@ -0,0 +1,73 @@
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines language-specific pattern match optimizations for Toy using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//
#ifndef TOY_COMBINE
#define TOY_COMBINE
#ifndef OP_BASE
include "toy/Ops.td"
#endif // OP_BASE
/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
/// dag sourcePattern, list<dag> resultPatterns,
/// list<dag> additionalConstraints = [],
/// dag benefitsAdded = (addBenefit 0)
/// >;
//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
(ReshapeOp $arg)>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite using Native Code Call
//===----------------------------------------------------------------------===//
// Native Code Calls may be used for more complex transformations using inline
// C++ and C++ helper functions.
// Reshape(Constant(x)) = x'
def ReshapeConstant :
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
(ReshapeOp:$res (ConstantOp $arg)),
(ConstantOp (ReshapeConstant $arg, $res))>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite with Constraints
//===----------------------------------------------------------------------===//
// DRR allows for constraint checking when the transformation is conditional
// on operand properties.
// Reshape(x) = x, where input and output shapes are identical
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
(ReshapeOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>;
#endif // TOY_COMBINE

View File

@ -1,403 +0,0 @@
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace toy {
namespace detail {
/// This class holds the implementation of the ToyArrayType.
/// It is intended to be uniqued based on its content and owned by the context.
struct ToyArrayTypeStorage : public mlir::TypeStorage {
/// This defines how we unique this type in the context: our key contains
/// only the shape, a more complex type would have multiple entries in the
/// tuple here.
/// The element of the tuples usually matches 1-1 the arguments from the
/// public `get()` method arguments from the facade.
using KeyTy = std::tuple<ArrayRef<int64_t>>;
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key));
}
/// When the key hash hits an existing type, we compare the shape themselves
/// to confirm we have the right type.
bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
/// This is a factory method to create our type storage. It is only
/// invoked after looking up the type in the context using the key and not
/// finding it.
static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
const KeyTy &key) {
// Copy the shape array into the bumpptr allocator owned by the context.
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
// Allocate the instance for the ToyArrayTypeStorage itself
auto *storage = allocator.allocate<ToyArrayTypeStorage>();
// Initialize the instance using placement new.
return new (storage) ToyArrayTypeStorage(shape);
}
ArrayRef<int64_t> getShape() const { return shape; }
private:
ArrayRef<int64_t> shape;
/// Constructor is only invoked from the `construct()` method above.
ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
};
} // namespace detail
mlir::Type ToyArrayType::getElementType() {
return mlir::FloatType::getF64(getContext());
}
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
ArrayRef<int64_t> shape) {
return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
}
ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
mlir::MemRefType ToyArrayType::toMemref() {
auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0);
return memRefType;
}
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
MulOp, AddOp, ReturnOp, AllocOp, TypeCastOp>();
addTypes<ToyArrayType>();
}
/// Parse a type registered to this dialect, we expect only Toy arrays.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
// Sanity check: we only support array or array<...>
if (!tyData.startswith("array")) {
emitError(loc, "invalid Toy type '" + tyData + "', array expected");
return nullptr;
}
// Drop the "array" prefix from the type name, we expect either an empty
// string or just the shape.
tyData = tyData.drop_front(StringRef("array").size());
// This is the generic array case without shape, early return it.
if (tyData.empty())
return ToyArrayType::get(getContext());
// Use a regex to parse the shape (for efficient we should store this regex in
// the dialect itself).
SmallVector<StringRef, 4> matches;
auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
if (!shapeRegex.match(tyData, &matches)) {
emitError(loc, "invalid toy array shape '" + tyData + "'");
return nullptr;
}
SmallVector<int64_t, 4> shape;
// Iterate through the captures, skip the first one which is the full string.
for (auto dimStr :
llvm::make_range(std::next(matches.begin()), matches.end())) {
if (dimStr.startswith(","))
continue; // POSIX misses non-capturing groups.
if (dimStr.empty())
continue; // '*' makes it an optional group capture
// Convert the capture to an integer
unsigned long long dim;
if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
return mlir::Type();
}
shape.push_back(dim);
}
// Finally we collected all the dimensions in the shape,
// create the array type.
return ToyArrayType::get(getContext(), shape);
}
/// Print a Toy array type, for example `array<2, 3, 4>`
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
auto arrayTy = type.dyn_cast<ToyArrayType>();
if (!arrayTy) {
os << "unknown toy type";
return;
}
os << "array";
if (!arrayTy.getShape().empty()) {
os << "<";
mlir::interleaveComma(arrayTy.getShape(), os);
os << ">";
}
}
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Helper to verify that the result of an operation is a Toy array type.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
if (!op->getResult()->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its argument, got "
<< op->getResult()->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
/// Helper to verify that the two operands of a binary operation are Toy
/// arrays..
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its LHS, got "
<< op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its LHS, got "
<< op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
auto dataAttribute = builder->getNamedAttr("value", value);
state.attributes.push_back(dataAttribute);
}
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::FloatAttr value) {
// Broadcast and forward to the other build factory
mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
auto dataType = builder->getTensorType({1}, elementType);
auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
.cast<mlir::DenseElementsAttr>();
ConstantOp::build(builder, state, {1}, dataAttribute);
}
/// Verifier for constant operation.
mlir::LogicalResult ConstantOp::verify() {
// Ensure that the return type is a Toy array
if (failed(verifyToyReturnArray(this)))
return mlir::failure();
// We expect the constant itself to be stored as an attribute.
auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
if (!dataAttr) {
return emitOpError(
"missing valid `value` DenseElementsAttribute on toy.constant()");
}
auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
if (!attrType) {
return emitOpError(
"missing valid `value` DenseElementsAttribute on toy.constant()");
}
// If the return type of the constant is not a generic array, the shape must
// match the shape of the attribute holding the data.
auto resultType = getResult()->getType().cast<ToyArrayType>();
if (!resultType.isGeneric()) {
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("The rank of the toy.constant return type must match "
"the one of the attached value attribute: " +
Twine(attrType.getRank()) +
" != " + Twine(resultType.getRank()));
}
for (int dim = 0; dim < attrType.getRank(); ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
std::string msg;
raw_string_ostream os(msg);
return emitOpError(
"Shape mismatch between toy.constant return type and its "
"attribute at dimension " +
Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
" != " + Twine(resultType.getShape()[dim]));
}
}
}
return mlir::success();
}
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns a generic ToyArray initially
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.assign(arguments.begin(), arguments.end());
auto calleeAttr = builder->getStringAttr(callee);
state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
}
mlir::LogicalResult GenericCallOp::verify() {
// Verify that every operand is a Toy Array
for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its " << opId << " operand, got "
<< getOperand(opId)->getType();
return emitOpError(os.str());
}
}
return mlir::success();
}
/// Return the name of the callee.
StringRef GenericCallOp::getCalleeName() {
return getAttr("callee").cast<mlir::StringAttr>().getValue();
}
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
std::string msg;
raw_string_ostream os(msg);
os << "expects a Toy Array for its argument, got "
<< op->getOperand()->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
// Return does not return any value and has an optional single argument
if (value)
state.operands.push_back(value);
}
mlir::LogicalResult ReturnOp::verify() {
if (getNumOperands() > 1)
return emitOpError("expects zero or one operand, got " +
Twine(getNumOperands()));
if (hasOperand() && failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
// Print does not return any value and has a single argument
state.operands.push_back(value);
}
mlir::LogicalResult PrintOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(value);
}
mlir::LogicalResult TransposeOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
return mlir::success();
}
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, ToyArrayType reshapedType) {
state.types.push_back(reshapedType);
state.operands.push_back(value);
}
mlir::LogicalResult ReshapeOp::verify() {
if (failed(verifyToySingleOperand(this)))
return mlir::failure();
auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
if (!retTy)
return emitOpError("toy.reshape is expected to produce a Toy array");
if (retTy.isGeneric())
return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
"got a generic one.");
return mlir::success();
}
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(lhs);
state.operands.push_back(rhs);
}
mlir::LogicalResult AddOp::verify() {
if (failed(verifyToyBinOperands(this)))
return mlir::failure();
return mlir::success();
}
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.types.push_back(ToyArrayType::get(builder->getContext()));
state.operands.push_back(lhs);
state.operands.push_back(rhs);
}
mlir::LogicalResult MulOp::verify() {
if (failed(verifyToyBinOperands(this)))
return mlir::failure();
return mlir::success();
}
void AllocOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Type retType) {
state.types.push_back(retType);
}
void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, mlir::Type destTy) {
state.operands.push_back(value);
state.types.push_back(destTy);
}
} // namespace toy

View File

@ -20,30 +20,23 @@
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Lowering.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
#include "toy/Passes.h"
#include "linalg1/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
using namespace toy;
@ -64,28 +57,14 @@ static cl::opt<enum InputType> inputType(
"load the input file as an MLIR file")));
namespace {
enum Action {
None,
DumpAST,
DumpMLIR,
DumpMLIRLinalg,
DumpLLVMDialect,
DumpLLVMIR,
RunJIT
};
enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine };
}
static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg",
"output the MLIR dump after linalg lowering")),
cl::values(clEnumValN(DumpLLVMDialect, "llvm-dialect",
"output the LLVM MLIR Dialect dump")),
cl::values(clEnumValN(DumpLLVMIR, "llvm-ir", "output the LLVM IR dump")),
cl::values(
clEnumValN(RunJIT, "jit",
"JIT the code and run it by invoking the main function")));
cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
"output the MLIR dump after affine lowering")));
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
@ -103,174 +82,81 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule();
}
mlir::LogicalResult optimize(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
return pm.run(module);
}
mlir::LogicalResult lowerDialect(mlir::ModuleOp module, bool OnlyLinalg) {
mlir::PassManager pm(module.getContext());
pm.addPass(createEarlyLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
if (!OnlyLinalg) {
pm.addPass(createLateLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
}
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
return pm.run(module);
}
mlir::OwningModuleRef loadFileAndProcessModule(
mlir::MLIRContext &context, bool EnableLinalgLowering = false,
bool EnableLLVMLowering = false, bool EnableOpt = false) {
mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return nullptr;
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return nullptr;
}
if (failed(mlir::verify(*module))) {
llvm::errs() << "Error verifying MLIR module\n";
return nullptr;
}
} else {
int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
!llvm::StringRef(inputFilename).endswith(".mlir")) {
auto moduleAST = parseInputFile(inputFilename);
module = mlirGen(context, *moduleAST);
return !module ? 1 : 0;
}
if (!module)
return nullptr;
if (EnableOpt) {
if (failed(optimize(*module))) {
llvm::errs() << "Module optimization failed\n";
return nullptr;
}
// Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
return -1;
}
if (EnableLLVMLowering || EnableLinalgLowering) {
if (failed(lowerDialect(*module, !EnableLLVMLowering))) {
llvm::errs() << "Module lowering failed\n";
return nullptr;
}
// Parse the input mlir.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
}
return module;
return 0;
}
int dumpMLIR() {
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
auto module =
loadFileAndProcessModule(context, /*EnableLinalgLowering=*/false,
/*EnableLLVMLowering=*/false, EnableOpt);
if (!module)
return -1;
mlir::OwningModuleRef module;
if (int error = loadMLIR(context, module))
return error;
mlir::PassManager pm(&context);
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
// Check to see what granularity of MLIR we are compiling to.
bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
if (EnableOpt || isLoweringToAffine) {
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
// Now that there is only one function, we can infer the shapes of each of
// the operations.
pm.addPass(mlir::toy::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
}
if (isLoweringToAffine) {
// Partially lower the toy dialect with a few cleanups afterwards.
pm.addPass(mlir::toy::createLowerToAffinePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCSEPass());
// Add optimizations if enabled.
if (EnableOpt) {
pm.addPass(mlir::createLoopFusionPass());
pm.addPass(mlir::createMemRefDataFlowOptPass());
}
}
if (mlir::failed(pm.run(*module)))
return 4;
module->dump();
return 0;
}
int dumpMLIRLinalg() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(context, /*EnableLinalgLowering=*/true,
/*EnableLLVMLowering=*/false,
/* EnableOpt=*/true);
if (!module)
return -1;
module->dump();
return 0;
}
int dumpLLVMDialect() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
/* EnableOpt=*/true);
if (!module) {
llvm::errs() << "Failed to load/lower MLIR module\n";
return -1;
}
module->dump();
return 0;
}
int dumpLLVMIR() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
/* EnableOpt=*/true);
if (!module) {
llvm::errs() << "Failed to load/lower MLIR module\n";
return -1;
}
auto llvmModule = translateModuleToLLVMIR(*module);
if (!llvmModule) {
llvm::errs() << "Failed to emit LLVM IR\n";
return -1;
}
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
auto optPipeline = mlir::makeOptimizingTransformer(
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0,
/* targetMachine=*/nullptr);
if (auto err = optPipeline(llvmModule.get())) {
llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
return -1;
}
llvm::errs() << *llvmModule << "\n";
return 0;
}
int runJit() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(
context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true,
/* EnableOpt=*/true);
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
// the module.
auto optPipeline = mlir::makeOptimizingTransformer(
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0,
/* targetMachine=*/nullptr);
auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline);
assert(maybeEngine && "failed to construct an execution engine");
auto &engine = maybeEngine.get();
// Invoke the JIT-compiled function with the arguments. Note that, for API
// uniformity reasons, it takes a list of type-erased pointers to arguments.
auto invocationResult = engine->invoke("main");
if (invocationResult) {
llvm::errs() << "JIT invocation failed\n";
return -1;
}
return 0;
}
int dumpAST() {
if (inputType == InputType::MLIR) {
llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
@ -286,10 +172,6 @@ int dumpAST() {
}
int main(int argc, char **argv) {
// Register our Dialects with MLIR
mlir::registerDialect<ToyDialect>();
mlir::registerDialect<linalg::LinalgDialect>();
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
@ -297,18 +179,10 @@ int main(int argc, char **argv) {
case Action::DumpAST:
return dumpAST();
case Action::DumpMLIR:
case Action::DumpMLIRAffine:
return dumpMLIR();
case Action::DumpMLIRLinalg:
return dumpMLIRLinalg();
case Action::DumpLLVMDialect:
return dumpLLVMDialect();
case Action::DumpLLVMIR:
return dumpLLVMIR();
case Action::RunJIT:
return runJit();
default:
llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
return -1;
}
return 0;

View File

@ -1,4 +1,4 @@
# Chapter 5: CodeGen via Lowering to Lower-Level Dialects
# Chapter 5 - Partial Lowering to Lower-Level Dialects for Optimization
At this point, we are eager to generate actual code and see our Toy language
taking life. We will obviously use LLVM to generate code, but just showing the
@ -6,293 +6,356 @@ LLVM builder interface wouldn't be very exciting here. Instead, we will show how
to perform progressive lowering through a mix of dialects coexisting in the same
function.
To make it more interesting, we will consider that we want to reuse existing
optimizations implemented in a dialect optimizing linear algebra: `Linalg`. This
dialect is tailored to the computation heavy part of the program, and is
limited: it doesn't support representing our `toy.print` builtin for instance,
neither should it! Instead we can target `Linalg` for the computation heavy part
of Toy (mostly matmul), we will target the `Affine` dialect for other
well-formed loop nest, and directly the `LLVM IR` dialect for lowering `print`.
To make it more interesting, in this chapter we will will consider that we want
to reuse existing optimizations implemented in a dialect optimizing affine
transformations: `Affine`. This dialect is tailored to the computation-heavy
part of the program, and is limited: it doesn't support representing our
`toy.print` builtin for instance, neither should it! Instead we can target
`Affine` for the computation heavy part of Toy, and in the
[next chapter](Ch-6.md) directly the `LLVM IR` dialect for lowering `print`. As
part of this lowering, we will be lowering from the
[TensorType](../../LangRef.md#tensor-type), that `Toy` operates on, to the
[MemRefType](../../LangRef.md#memref-type) that is indexed via an affine
loop-nest. Tensors represent an abstract value-typed sequence of data, meaning
that they don't live in any memory. MemRefs on the other hand represent lower
level buffer access, as they are concrete references to a region of memory.
# The `DialectConversion` Framework
# Dialect Conversions
Similarly to the canonicalization patterns introduced in the previous section,
the `DialectConversion` framework involves its own set of patterns. This
framework operates a bit differently from the canonicalizer: a new function is
created and the pattern matching operation in the original function are expected
to emit the IR in the new function.
MLIR contains many different dialects, so it is important to have a unified
framework for converting between them. This is where the `DialectConversion`
framework comes into play. This framework allows for transforming a set of
`illegal` operations to a set of `legal` ones. To use this framework we need to
provide two things:
Dialect conversion requires three components, implemented by overriding virtual
methods defined in `DialectConversion`:
* A [Conversion Target](../../DialectConversion.md#conversion-target)
- Type Conversion: for things like block arguments' type.
- Function signature conversion: for every function it is invoked with the
function type and the conversion generates a new prototype for the converted
function. The default implementation will call into the type conversion for
the returned values and for each of the parameters.
- Operations conversions: each pattern is expected to generate new results
matching the current operations' in the new function. This may involve
generating one or multiple new operations, or possibly just remapping
existing operands (folding).
- This is the formal specification of what operations, or dialects, are
legal for the conversion. Operations that aren't legal will require
rewrite patterns to perform legalization.
A typical starting point for implementing our lowering would be:
* A set of
[Rewrite Patterns](../../DialectConversion.md#rewrite-pattern-specification)
- These are the set of [patterns](../../QuickstartRewrites.md) used to
convert `illegal` operations into a set of zero or more `legal` ones.
* Optionally, A [Type Converter](../../DialectConversion.md#type-conversion).
- If provided, this is used to convert the types of block arguments. We
won't be needing this for our conversion.
## Conversion Target
For our purposes, we want to convert the compute intensive `Toy` operations into
a combination of operations from the `Affine` `Standard` dialects for further
optimization. To start off the lowering, we first define our conversion target:
```c++
class Lowering : public DialectConversion {
public:
// This gets called for block and region arguments, and attributes.
Type convertType(Type t) override { /*...*/ }
void ToyToAffineLoweringPass::runOnFunction() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
mlir::ConversionTarget target(getContext());
// This gets called for functions.
FunctionType convertFunctionSignatureType(FunctionType type,
ArrayRef<NamedAttributeList> argAttrs,
SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) { /*...*/ }
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
// This gets called once to set up operation converters.
llvm::DenseSet<ConversionPattern *>
initConverters(MLIRContext *context) override {
RewriteListBuilder<MulOpConversion, PrintOpConversion,
TransposeOpConversion>::build(allocator, context);
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as `legal`.
target.addIllegalDialect<ToyDialect>();
target.addLegalOp<PrintOp>();
...
}
```
## Conversion Patterns
After the conversion target has been defined, we can define how to convert the
`illegal` operations into `legal` ones. Similarly to the canonicalization
framework introduced in [chapter 3](Ch-3.md), the
[`DialectConversion` framework](../../DialectConversion.md) also uses
[RewritePatterns](../../QuickstartRewrites.md) to perform the conversion logic.
These patterns may be the `RewritePatterns` seen before, or a new type of
pattern specific to the conversion framework `ConversionPattern`.
`ConversionPatterns` are different from traditional `RewritePatterns` in that
they accept an additional `operands` parameter containing operands that have
been remapped/replaced. This is used when dealing with type conversions as the
pattern will want to operand on values of the new type, but match against the
old. For our lowering, this invariant will be useful during our lowering as we
will be translating from the [TensorType](../../LangRef.md#tensor-type),
currently being operated on, to the [MemRefType](../../LangRef.md#memref-type).
Let's look at a snippet of lowering the `toy.transpose` operation:
```c++
/// Lower the `toy.transpose` operation to an affine loop nest.
struct TransposeOpLowering : public mlir::ConversionPattern {
TransposeOpLowering(mlir::MLIRContext *ctx)
: mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
/// Match and rewrite the given `toy.transpose` operation, with the given
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
mlir::PatternMatchResult
matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value *> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
// Call to a helper function that will lower the current operation to a set
// of affine loops. We provide a functor that operates on the remapped
// operands, as well as the loop induction variables for the inner most
// loop body.
lowerOpToLoops(
op, operands, rewriter,
[loc](mlir::PatternRewriter &rewriter,
ArrayRef<mlir::Value *> memRefOperands,
ArrayRef<mlir::Value *> loopIvs) {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS. This adaptor is automatically provided by the ODS
// framework.
TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands);
mlir::Value *input = tranposeAdaptor.input();
// Transpose the elements by generating a load from the reverse
// indices.
SmallVector<mlir::Value *, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
});
return matchSuccess();
}
private:
llvm::BumpPtrAllocator allocator;
};
```
Individual operation converters are following this pattern:
Now we can prepare the list of patterns to use during the lowering process:
```c++
/// Lower a toy.add to an affine loop nest.
///
/// This class inherit from `ConversionPattern` and override `rewrite`,
/// similarly to the PatternRewriter introduced in the previous chapter.
/// It will be called by the DialectConversion framework (see `LateLowering`
/// class below).
class AddOpConversion : public ConversionPattern {
public:
explicit AddOpConversion(MLIRContext *context)
: ConversionPattern(toy::AddOp::getOperationName(), 1, context) {}
void ToyToAffineLoweringPass::runOnFunction() {
...
/// Lower the `op` by generating IR using the `rewriter` builder. The builder
/// is setup with a new function, the `operands` array has been populated with
/// the rewritten operands for `op` in the new function.
/// The results created by the new IR with the builder are returned, and their
/// number must match the number of result of `op`.
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
OpBuilder &rewriter) const override {
...
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
mlir::OwningRewritePatternList patterns;
patterns.insert<..., TransposeOpLowering>(&getContext());
// Return the newly allocated buffer, it will be used as an operand when
// converting the operations corresponding to the users of this `toy.add`.
return result;
}
...
```
## Linalg
## Partial Lowering
Linalg is an advanced dialect for dense algebra optimizations. It is implemented
as [a separate tutorial](../Linalg/Ch-1.md) in parallel with Toy. We are acting
as a user of this dialect by lowering Toy matrix multiplications to
`linalg.matmul`.
Once the patterns have been defined, we can perform the actual lowering. The
`DialectConversion` framework provides several different modes of lowering, but
for our purposes we will be performing a partial lowering, as we will not be
converting `toy.print` at this time.
To support this, we will split our lowering in two parts: an *early lowering*
that emits operations in the `Linalg` dialect for a subset of the Toy IR, and a
*late lowering* that materializes buffers and converts all operations and type
to the LLVM dialect. We will then be able to run specific optimizations in
between the two lowering.
```c++
void ToyToAffineLoweringPass::runOnFunction() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
mlir::ConversionTarget target(getContext());
Let's look again at our example `multiply_transpose`:
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<mlir::AffineOpsDialect, mlir::StandardOpsDialect>();
```mlir
func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
attributes {toy.generic: true} {
%0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
%1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
"toy.return"(%1) : (!toy.array) -> ()
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as `legal`.
target.addIllegalDialect<ToyDialect>();
target.addLegalOp<PrintOp>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
mlir::OwningRewritePatternList patterns;
patterns.insert<..., TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
auto function = getFunction();
if (mlir::failed(mlir::applyPartialConversion(function, target, patterns)))
signalPassFailure();
}
```
After shape inference, and lowering to `Linalg`, here is what our IR will look
like:
### Design Considerations With Partial Lowering
Before diving into the result of our lowering, this is a good time to discuss
potential design considerations when it comes to partial lowering. In our
lowering, we will be transforming from a value-type, TensorType, to a
allocated(buffer-like) type, MemRefType. Given that we will not be lowering the
`toy.print` operation, we need to temporarily bridge these two worlds. There are
many ways to go about this, each with their own tradeoffs:
* Generate `load` operations from the buffer
One option is to generate `load` operations from the buffer type to materialize
an instance of the value type. This allows for the definition of the `toy.print`
operation to remain unchanged. The downside to this approach is that the
optimizations on the `affine` dialect are limited, because the `load` will
actually involve a full copy that is only visible *after* our optimizations have
been performed.
* Generate a new version of `toy.print` that operates on the lowered type
Another option would be to have another, lowered, variant of `toy.print` that
operates on the lowered type. The benefit of this option is that there is no
hidden, unnecessary, copy to optimizer. The downside is that another operation
definition is needed, that may duplicate many aspects of the first. Defining a
base class in [ODS](../../OpDefinitions.md) may simplify this, but you still
need to treat these operations separately.
* Update `toy.print` to allow for operating on the lowered type
A third option is to update the current definition of `toy.print` to allow for
operating the on the lowered type. The benefit of this approach is that it is
simple, does not introduce an additional hidden copy, and does not require
another operation definition. The downside to this option is that it requires
mixing abstraction levels in the `Toy` dialect.
For the sake of simplicity, we will use the third option for this lowering. This
involves updating the type constraints on the PrintOp in the operation
definition file:
```tablegen
def PrintOp : Toy_Op<"print"> {
...
// The print operation takes an input tensor to print.
// We also allow a F64MemRef to enable interop during partial lowering.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}
```
## Complete Toy Example
Looking back at our current working example:
```.mlir
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
%3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
"toy.print"(%3) : (tensor<2x3xf64>) -> ()
"toy.return"() : () -> ()
}
```
With affine lowering added to our pipeline, we can now generate:
```mlir
func @multiply_transpose_2x3_2x3(%arg0: !toy.array<2, 3>, %arg1: !toy.array<2, 3>) -> !toy.array<2, 2>
attributes {toy.generic: false} {
%c3 = constant 3 : index
func @main() {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c1 = constant 1 : index
%0 = "toy.transpose"(%arg1) : (!toy.array<2, 3>) -> !toy.array<3, 2>
%1 = "toy.alloc"() : () -> !toy.array<2, 2>
%2 = "toy.cast"(%1) : (!toy.array<2, 2>) -> memref<2x2xf64>
%3 = "toy.cast"(%arg0) : (!toy.array<2, 3>) -> memref<2x3xf64>
%4 = "toy.cast"(%0) : (!toy.array<3, 2>) -> memref<3x2xf64>
%5 = linalg.range %c0:%c2:%c1 : !linalg.range
%6 = linalg.range %c0:%c3:%c1 : !linalg.range
%7 = linalg.view %3[%5, %6] : !linalg<"view<?x?xf64>">
%8 = linalg.view %4[%6, %5] : !linalg<"view<?x?xf64>">
%9 = linalg.view %2[%5, %5] : !linalg<"view<?x?xf64>">
linalg.matmul(%7, %8, %9) : !linalg<"view<?x?xf64>">
"toy.return"(%1) : (!toy.array<2, 2>) -> ()
}
```
%c2 = constant 2 : index
%cst = constant 1.000000e+00 : f64
%cst_0 = constant 2.000000e+00 : f64
%cst_1 = constant 3.000000e+00 : f64
%cst_2 = constant 4.000000e+00 : f64
%cst_3 = constant 5.000000e+00 : f64
%cst_4 = constant 6.000000e+00 : f64
Note how the operations from multiple dialects are coexisting in this function.
// Allocating buffers for the inputs and outputs.
%0 = alloc() : memref<2x3xf64>
%1 = alloc() : memref<2x3xf64>
%2 = alloc() : memref<2x3xf64>
You can reproduce this result with `bin/toyc-ch5
test/Examples/Toy/Ch5/lowering.toy -emit=mlir-linalg`
// Initialize the input buffer with the constant values.
affine.store %cst, %2[%c0, %c0] : memref<2x3xf64>
affine.store %cst_0, %2[%c0, %c1] : memref<2x3xf64>
affine.store %cst_1, %2[%c0, %c2] : memref<2x3xf64>
affine.store %cst_2, %2[%c1, %c0] : memref<2x3xf64>
affine.store %cst_3, %2[%c1, %c1] : memref<2x3xf64>
affine.store %cst_4, %2[%c1, %c2] : memref<2x3xf64>
## Emitting LLVM
The availability of various dialects allows for a smooth lowering by reducing
the impedance mismatch between dialects. For example we don't need to lower our
`toy.print` over array directly to LLVM IR, we can use the well structured loop
from the `Affine` dialect for convenience when scanning the array and insert a
call to `llvm.printf` in the body. We will rely on MLIR lowering to LLVM for the
`Affine` dialect, we get it for free. Here is a simplified version of the code
in this chapter for lowering `toy.print`:
```c++
// Create our loop nest now
using namespace edsc;
using llvmCall = intrinsics::ValueBuilder<LLVM::CallOp>;
ScopedContext scope(rewriter, loc);
ValueHandle zero = intrinsics::constant_index(0);
ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f "));
ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n"));
MemRefView vOp(operand);
IndexedValue iOp(operand);
IndexHandle i, j, M(vOp.ub(0)), N(vOp.ub(1));
LoopBuilder(&i, zero, M, 1)({
LoopBuilder(&j, zero, N, 1)({
llvmCall(retTy,
rewriter.getSymbolRefAttr(printfFunc),
{fmtCst, iOp(i, j)})
}),
llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol})
});
```
For instance the Toy IR may contain:
```
"toy.print"(%0) : (!toy.array<2, 2>) -> ()
```
which the converter above will turn into this sequence:
```mlir
affine.for %i0 = 0 to 2 {
affine.for %i1 = 0 to 2 {
%3 = load %0[%i0, %i1] : memref<2x2xf64>
%4 = llvm.call @printf(%1, %3) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
// Load the transpose value from the input buffer and store it into the
// next input buffer.
affine.for %arg0 = 0 to 2 {
affine.for %arg1 = 0 to 3 {
%3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
affine.store %3, %1[%arg0, %arg1] : memref<2x3xf64>
}
%5 = llvm.call @printf(%2, %cst_21) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
}
// Multiply and store into the output buffer.
affine.for %arg0 = 0 to 2 {
affine.for %arg1 = 0 to 3 {
%3 = affine.load %1[%arg0, %arg1] : memref<2x3xf64>
%4 = affine.load %1[%arg0, %arg1] : memref<2x3xf64>
%5 = mulf %3, %4 : f64
affine.store %5, %0[%arg0, %arg1] : memref<2x3xf64>
}
}
// Print the value held by the buffer.
"toy.print"(%0) : (memref<2x3xf64>) -> ()
dealloc %2 : memref<2x3xf64>
dealloc %1 : memref<2x3xf64>
dealloc %0 : memref<2x3xf64>
return
}
```
Note the mix of a loop nest in the `Affine` dialect, with an operation
`llvm.call` in the body. MLIR knows already how to lower this to:
## Taking Advantage of Affine Optimization
Our naive lowering is correct, but it leaves a lot to be desired in regards to
efficiency; For example the lowering of `toy.mul` has generated some redundant
loads. Let's look at how adding a few existing optimizations to the pipeline can
help clean this up. Adding the `LoopFusion` and `MemRefDataFlowOpt` passes to
the pipeline gives the following result:
```mlir
llvm.br ^bb1(%87 : !llvm.i64)
^bb1(%89: !llvm.i64): // 2 preds: ^bb0, ^bb5
%90 = llvm.icmp "slt" %89, %88 : !llvm.i64
llvm.cond_br %90, ^bb2, ^bb6
^bb2: // pred: ^bb1
%91 = llvm.mlir.constant(0 : index) : !llvm.i64
%92 = llvm.mlir.constant(2 : index) : !llvm.i64
llvm.br ^bb3(%91 : !llvm.i64)
^bb3(%93: !llvm.i64): // 2 preds: ^bb2, ^bb4
%94 = llvm.icmp "slt" %93, %92 : !llvm.i64
llvm.cond_br %94, ^bb4, ^bb5
^bb4: // pred: ^bb3
%95 = llvm.mlir.constant(2 : index) : !llvm.i64
%96 = llvm.mlir.constant(2 : index) : !llvm.i64
%97 = llvm.mul %89, %96 : !llvm.i64
%98 = llvm.add %97, %93 : !llvm.i64
%99 = llvm.getelementptr %6[%98] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
%100 = llvm.load %99 : !llvm<"double*">
%101 = llvm.call @printf(%48, %100) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
%102 = llvm.mlir.constant(1 : index) : !llvm.i64
%103 = llvm.add %93, %102 : !llvm.i64
llvm.br ^bb3(%103 : !llvm.i64)
^bb5: // pred: ^bb3
%104 = llvm.call @printf(%76, %71) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
%105 = llvm.mlir.constant(1 : index) : !llvm.i64
%106 = llvm.add %89, %105 : !llvm.i64
llvm.br ^bb1(%106 : !llvm.i64)
```
func @main() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%cst = constant 1.000000e+00 : f64
%cst_0 = constant 2.000000e+00 : f64
%cst_1 = constant 3.000000e+00 : f64
%cst_2 = constant 4.000000e+00 : f64
%cst_3 = constant 5.000000e+00 : f64
%cst_4 = constant 6.000000e+00 : f64
We appreciate the ease to generate the former, as well as the readability!
// Allocating buffers for the inputs and outputs.
%0 = alloc() : memref<2x3xf64>
%1 = alloc() : memref<2x3xf64>
You may reproduce these results with `echo "def main() { print([[1,2],[3,4]]); }
" | bin/toyc-ch5 -x toy - -emit=llvm-dialect` and `echo "def main() {
print([[1,2],[3,4]]); } " | bin/toyc-ch5 -x toy - -emit=llvm-ir`.
// Initialize the input buffer with the constant values.
affine.store %cst, %1[%c0, %c0] : memref<2x3xf64>
affine.store %cst_0, %1[%c0, %c1] : memref<2x3xf64>
affine.store %cst_1, %1[%c0, %c2] : memref<2x3xf64>
affine.store %cst_2, %1[%c1, %c0] : memref<2x3xf64>
affine.store %cst_3, %1[%c1, %c1] : memref<2x3xf64>
affine.store %cst_4, %1[%c1, %c2] : memref<2x3xf64>
# CodeGen: Getting Out of MLIR
affine.for %arg0 = 0 to 2 {
affine.for %arg1 = 0 to 3 {
// Load the transpose value from the input buffer.
%2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64>
At this point, all the IR is expressed in the LLVM dialect, MLIR can perform a
straight conversion to an LLVM module. You may look into
[`Ch5/toyc.cpp`](../../../examples/toy/Ch5/toyc.cpp) for the `dumpLLVM()`
function:
```c++
int dumpLLVM() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
auto llvmModule = translateModuleToLLVMIR(*module);
if (!llvmModule) {
llvm::errs() << "Failed to emit LLVM IR\n";
return -1;
// Multiply and store into the output buffer.
%3 = mulf %2, %2 : f64
affine.store %3, %0[%arg0, %arg1] : memref<2x3xf64>
}
}
llvm::errs() << *llvmModule << "\n";
return 0;
// Print the value held by the buffer.
"toy.print"(%0) : (memref<2x3xf64>) -> ()
dealloc %1 : memref<2x3xf64>
dealloc %0 : memref<2x3xf64>
return
}
```
Adding a JIT isn't much more involved either:
Here we can see that an allocation was removed, the two loop nests were fused,
and we also were able to remove an unnecessary allocation! You can build
`toyc-ch5` and try yourself: `toyc-ch5 test/lowering.toy -emit=mlir-affine`. We
can also check our optimizations by adding `-opt`.
```c++
int runJit() {
mlir::MLIRContext context;
auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true);
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
// Create an MLIR execution engine. Note that it takes a null pass manager
// to make sure it won't run "default" passes on the MLIR that would trigger
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
// the module.
auto maybeEngine =
mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr);
assert(maybeEngine && "failed to construct an execution engine");
auto &engine = maybeEngine.get();
// Invoke the JIT-compiled function with the arguments. Note that, for API
// uniformity reasons, it takes a list of type-erased pointers to arguments.
auto invocationResult = engine->invoke("main");
if(invocationResult) {
llvm::errs() << "JIT invocation failed\n";
return -1;
}
return 0;
}
```
You can play with it, from the build directory:
```bash
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch5 -emit=jit
1.000000 2.000000
3.000000 4.000000
```
You can also play with `-emit=mlir`, `-emit=mlir-linalg`, `-emit=llvm-dialect`,
and `-emit=llvm-ir` to compare the various level of IR involved. Try also
options like `--print-ir-after-all` to track the evolution of the IR throughout
the pipeline.
In this chapter we explored some aspects of partial lowering, with the intent to
optimize. In the [next chapter](Ch-6.md) we will continue the discussion about
dialect conversion by targeting LLVM for code generation.

View File

@ -1,7 +1,6 @@
# RUN: toyc-ch4 %s -emit=ast 2>&1 | FileCheck %s
# User defined generic function that operates solely on
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
return a * transpose(b);
}
@ -11,8 +10,10 @@ def main() {
# The shape is inferred from the supplied literal.
var a = [[1, 2, 3], [4, 5, 6]];
# b is identical to a, the literal array is implicitly reshaped: defining new
# variables is the way to reshape arrays (element count must match).
# variables is the way to reshape arrays (element count in literal must match
# the size of specified shape).
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# This call will specialize `multiply_transpose` with <2, 3> for both
# arguments and deduce a return type of <2, 2> in initialization of `c`.
var c = multiply_transpose(a, b);
@ -30,44 +31,44 @@ def main() {
# CHECK: Module:
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1'
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1'
# CHECK-NEXT: Params: [a, b]
# CHECK-NEXT: Block {
# CHECK-NEXT: Retur
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14
# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14
# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14
# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14
# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24
# CHECK-NEXT: ]
# CHECK-NEXT: } // Block
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1'
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1'
# CHECK-NEXT: Params: []
# CHECK-NEXT: Block {
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11
# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11
# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30
# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11
# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30
# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30
# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30
# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40
# CHECK-NEXT: ]
# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44
# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44
# CHECK-NEXT: ]

View File

@ -0,0 +1,65 @@
// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s
// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
%3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
"toy.print"(%3) : (tensor<2x3xf64>) -> ()
"toy.return"() : () -> ()
}
// CHECK-LABEL: func @main()
// CHECK: [[VAL_0:%.*]] = constant 1.000000e+00 : f64
// CHECK: [[VAL_1:%.*]] = constant 2.000000e+00 : f64
// CHECK: [[VAL_2:%.*]] = constant 3.000000e+00 : f64
// CHECK: [[VAL_3:%.*]] = constant 4.000000e+00 : f64
// CHECK: [[VAL_4:%.*]] = constant 5.000000e+00 : f64
// CHECK: [[VAL_5:%.*]] = constant 6.000000e+00 : f64
// CHECK: [[VAL_6:%.*]] = alloc() : memref<2x3xf64>
// CHECK: [[VAL_7:%.*]] = alloc() : memref<2x3xf64>
// CHECK: [[VAL_8:%.*]] = alloc() : memref<2x3xf64>
// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64>
// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64>
// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64>
// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64>
// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64>
// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64>
// CHECK: affine.for [[VAL_9:%.*]] = 0 to 2 {
// CHECK: affine.for [[VAL_10:%.*]] = 0 to 3 {
// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64>
// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64>
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 2 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 3 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
// CHECK: [[VAL_16:%.*]] = mulf [[VAL_14]], [[VAL_15]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64>
// CHECK: "toy.print"([[VAL_6]]) : (memref<2x3xf64>) -> ()
// CHECK: dealloc [[VAL_8]] : memref<2x3xf64>
// CHECK: dealloc [[VAL_7]] : memref<2x3xf64>
// CHECK: dealloc [[VAL_6]] : memref<2x3xf64>
// OPT-LABEL: func @main()
// OPT: [[VAL_1:%.*]] = constant 1.000000e+00 : f64
// OPT: [[VAL_2:%.*]] = constant 2.000000e+00 : f64
// OPT: [[VAL_3:%.*]] = constant 3.000000e+00 : f64
// OPT: [[VAL_4:%.*]] = constant 4.000000e+00 : f64
// OPT: [[VAL_5:%.*]] = constant 5.000000e+00 : f64
// OPT: [[VAL_6:%.*]] = constant 6.000000e+00 : f64
// OPT: [[VAL_7:%.*]] = alloc() : memref<2x3xf64>
// OPT: [[VAL_8:%.*]] = alloc() : memref<2x3xf64>
// OPT: affine.store [[VAL_1]], [[VAL_8]]{{\[}}0, 0] : memref<2x3xf64>
// OPT: affine.store [[VAL_2]], [[VAL_8]]{{\[}}0, 1] : memref<2x3xf64>
// OPT: affine.store [[VAL_3]], [[VAL_8]]{{\[}}0, 2] : memref<2x3xf64>
// OPT: affine.store [[VAL_4]], [[VAL_8]]{{\[}}1, 0] : memref<2x3xf64>
// OPT: affine.store [[VAL_5]], [[VAL_8]]{{\[}}1, 1] : memref<2x3xf64>
// OPT: affine.store [[VAL_6]], [[VAL_8]]{{\[}}1, 2] : memref<2x3xf64>
// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 {
// OPT: affine.for [[VAL_10:%.*]] = 0 to 3 {
// OPT: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64>
// OPT: [[VAL_12:%.*]] = mulf [[VAL_11]], [[VAL_11]] : f64
// OPT: affine.store [[VAL_12]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64>
// OPT: "toy.print"([[VAL_7]]) : (memref<2x3xf64>) -> ()
// OPT: dealloc [[VAL_8]] : memref<2x3xf64>
// OPT: dealloc [[VAL_7]] : memref<2x3xf64>

View File

@ -1,7 +1,6 @@
# RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s
# User defined generic function that operates solely on
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
return a * transpose(b);
}
@ -10,9 +9,11 @@ def main() {
# Define a variable `a` with shape <2, 3>, initialized with the literal value.
# The shape is inferred from the supplied literal.
var a = [[1, 2, 3], [4, 5, 6]];
# b is identical to a, the literal array is implicitely reshaped: defining new
# variables is the way to reshape arrays (element count must match).
# b is identical to a, the literal array is implicitly reshaped: defining new
# variables is the way to reshape arrays (element count in literal must match
# the size of specified shape).
var b<2, 3> = [1, 2, 3, 4, 5, 6];
# This call will specialize `multiply_transpose` with <2, 3> for both
# arguments and deduce a return type of <2, 2> in initialization of `c`.
var c = multiply_transpose(a, b);
@ -30,44 +31,44 @@ def main() {
# CHECK: Module:
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1'
# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1'
# CHECK-NEXT: Params: [a, b]
# CHECK-NEXT: Block {
# CHECK-NEXT: Retur
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14
# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14
# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24
# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14
# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14
# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24
# CHECK-NEXT: ]
# CHECK-NEXT: } // Block
# CHECK-NEXT: Function
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1'
# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1'
# CHECK-NEXT: Params: []
# CHECK-NEXT: Block {
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11
# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3
# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11
# CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3
# CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11
# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30
# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33
# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11
# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30
# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33
# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30
# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11
# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30
# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33
# CHECK-NEXT: ]
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40
# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3
# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11
# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30
# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40
# CHECK-NEXT: ]
# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44
# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44
# CHECK-NEXT: ]

View File

@ -13,20 +13,19 @@ def main() {
print(d);
}
# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
# CHECK-NEXT: attributes {toy.generic = true} {
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
# CHECK-NEXT: }
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
# CHECK-NEXT: attributes {toy.generic} {
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()
# CHECK-LABEL: func @main() {
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()

View File

@ -1,11 +1,9 @@
// RUN: not toyc-ch5 %s -emit=mlir 2>&1
// This IR is not "valid":
// The following IR is not "valid":
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
// This all round-trip since this is opaque for MLIR.
func @main() {
%0 = "toy.print"() : () -> !toy.array<2, 3>
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@ -1,16 +0,0 @@
# RUN: toyc-ch5 %s -emit=llvm-ir 2>&1 | FileCheck %s
# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return a * transpose(b);
}
# CHECK: define void @main() {
# CHECK: %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6))
def main() {
var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
var b<2, 3> = [1, 2, 3, 4, 5, 6];
var c = multiply_transpose(a, b);
var d = multiply_transpose(b, a);
print(d);
}

View File

@ -6,9 +6,9 @@ def main() {
}
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-NEXT: }

View File

@ -0,0 +1,30 @@
// RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s
// Check the result of inlining+shape inference on an input module.
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
%1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
"toy.return"(%1) : (tensor<*xf64>) -> ()
}
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
"toy.print"(%5) : (tensor<*xf64>) -> ()
"toy.return"() : () -> ()
}
// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: tensor<*xf64>
// CHECK-LABEL: func @main() {
// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>
// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> ()
// CHECK: "toy.return"() : () -> ()

View File

@ -1,19 +0,0 @@
# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
def transpose_transpose(x) {
return transpose(transpose(x));
}
def main() {
print(transpose_transpose([[1, 2], [3, 4]]));
}
#CHECK-LABEL: func @transpose_transpose
#CHECK: transpose
#CHECK-LABEL: main
#OPT-LABEL: func @transpose_transpose
#OPT-NOT: transpose

View File

@ -1,24 +0,0 @@
# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
# We expect no reshape in this function with optimizations enabled
def foo(a) {
var b<2,1> = a;
var c<2,1> = b;
print(c);
}
def main() {
var a<2, 1> = [1, 2];
foo(a);
}
# without optimizations, match the reshape
#CHECK-LABEL: func @foo
#CHECK: reshape
#CHECK-LABEL: main
# with optimizations, ensure no reshape
#OPT-LABEL: main
#OPT-LABEL: func @foo_2x1
#OPT-NOT: reshape