NFC: Tidy up the implementation of operations in the Toy tutorial

Use header blocks to separate operation implementations, and switch the build methods to be out-of-line when possible.

PiperOrigin-RevId: 278982913
This commit is contained in:
River Riddle 2019-11-06 18:21:04 -08:00 committed by A. Unique TensorFlower
parent 22cfff7043
commit 2fddfcfb14
11 changed files with 251 additions and 208 deletions

View File

@ -75,15 +75,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &state, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
build(builder, state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
OpBuilder<"Builder *builder, OperationState &state, double value">
];
// Invoke a static verify method to verify this constant operation.
@ -102,10 +100,8 @@ def AddOp : Toy_Op<"add"> {
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def GenericCallOp : Toy_Op<"generic_call"> {
@ -134,11 +130,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
];
}
@ -154,10 +147,8 @@ def MulOp : Toy_Op<"mul"> {
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def PrintOp : Toy_Op<"print"> {
@ -210,7 +201,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
@ -228,12 +219,10 @@ def TransposeOp : Toy_Op<"transpose"> {
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor);
// Allow building a TransposeOp with from the two input operands.
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *input">
];
// Invoke a static verify method to verify this transpose operation.
let verifier = [{ return ::verify(*this); }];

View File

@ -45,14 +45,17 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
// Toy Operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder,
mlir::OperationState &result, double value) {
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = RankedTensorType::get({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, result, dataType, dataAttribute);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
@ -60,7 +63,8 @@ static void buildConstantOp(mlir::Builder *builder,
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<mlir::RankedTensorType>();
auto resultType =
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
if (!resultType)
return success();
@ -86,27 +90,38 @@ static mlir::LogicalResult verify(ConstantOp op) {
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &result,
mlir::Value *lhs, mlir::Value *rhs) {
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
result.addOperands({lhs, rhs});
//===----------------------------------------------------------------------===//
// AddOp
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &result, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
//===----------------------------------------------------------------------===//
// GenericCallOp
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
result.addOperands(arguments);
result.addAttribute("callee", builder->getSymbolRefAttr(callee));
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &result,
mlir::Value *lhs, mlir::Value *rhs) {
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
result.addOperands({lhs, rhs});
//===----------------------------------------------------------------------===//
// MulOp
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
//===----------------------------------------------------------------------===//
// ReturnOp
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -142,10 +157,13 @@ static mlir::LogicalResult verify(ReturnOp op) {
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &result, mlir::Value *value) {
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
result.addOperands(value);
//===----------------------------------------------------------------------===//
// TransposeOp
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}
static mlir::LogicalResult verify(TransposeOp op) {

View File

@ -75,15 +75,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &state, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
build(builder, state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
OpBuilder<"Builder *builder, OperationState &state, double value">
];
// Invoke a static verify method to verify this constant operation.
@ -102,10 +100,8 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def GenericCallOp : Toy_Op<"generic_call"> {
@ -134,11 +130,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
];
}
@ -154,10 +147,8 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def PrintOp : Toy_Op<"print"> {
@ -213,7 +204,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
@ -234,12 +225,10 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
// Enable registering canonicalization patterns with this operation.
let hasCanonicalizer = 1;
// Allow building a TransposeOp with from the two input operands.
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *input">
];
// Invoke a static verify method to verify this transpose operation.
let verifier = [{ return ::verify(*this); }];

View File

@ -45,11 +45,14 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
// Toy Operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = RankedTensorType::get({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
@ -60,7 +63,8 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
auto resultType =
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
if (!resultType)
return success();
@ -86,27 +90,38 @@ static mlir::LogicalResult verify(ConstantOp op) {
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// AddOp
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
//===----------------------------------------------------------------------===//
// GenericCallOp
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// MulOp
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
//===----------------------------------------------------------------------===//
// ReturnOp
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -142,8 +157,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
//===----------------------------------------------------------------------===//
// TransposeOp
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}

View File

@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &state, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
build(builder, state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
OpBuilder<"Builder *builder, OperationState &state, double value">
];
// Invoke a static verify method to verify this constant operation.
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def CastOp : Toy_Op<"cast",
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
];
}
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def PrintOp : Toy_Op<"print"> {
@ -237,7 +228,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
@ -257,12 +248,10 @@ def TransposeOp : Toy_Op<"transpose",
let results = (outs F64Tensor);
let hasCanonicalizer = 1;
// Allow building a TransposeOp with from the two input operands.
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *input">
];
// Invoke a static verify method to verify this transpose operation.
let verifier = [{ return ::verify(*this); }];

View File

@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
// Toy Operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = RankedTensorType::get({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
auto resultType =
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
if (!resultType)
return success();
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// AddOp
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
//===----------------------------------------------------------------------===//
// CastOp
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
//===----------------------------------------------------------------------===//
// GenericCallOp
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// MulOp
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
//===----------------------------------------------------------------------===//
// TransposeOp
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}

View File

@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &state, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
build(builder, state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
OpBuilder<"Builder *builder, OperationState &state, double value">
];
// Invoke a static verify method to verify this constant operation.
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def CastOp : Toy_Op<"cast",
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
];
}
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def PrintOp : Toy_Op<"print"> {
@ -238,7 +229,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
@ -258,12 +249,10 @@ def TransposeOp : Toy_Op<"transpose",
let results = (outs F64Tensor);
let hasCanonicalizer = 1;
// Allow building a TransposeOp with from the two input operands.
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *input">
];
// Invoke a static verify method to verify this transpose operation.
let verifier = [{ return ::verify(*this); }];

View File

@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
// Toy Operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = RankedTensorType::get({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
auto resultType =
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
if (!resultType)
return success();
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// AddOp
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
//===----------------------------------------------------------------------===//
// CastOp
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
//===----------------------------------------------------------------------===//
// GenericCallOp
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// MulOp
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
//===----------------------------------------------------------------------===//
// TransposeOp
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}

View File

@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// using `builder.create<ConstantOp>(...)`.
let builders = [
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
OpBuilder<"Builder *builder, OperationState &state, "
"DenseElementsAttr value", [{
build(builder, result, value.getType(), value);
build(builder, state, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
OpBuilder<"Builder *builder, OperationState &state, double value">
];
// Invoke a static verify method to verify this constant operation.
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
// Allow building an AddOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildAddOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def CastOp : Toy_Op<"cast",
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
// Add custom build methods for the generic call operation.
let builders = [
// Build a generic call with the given callee name and arguments.
OpBuilder<"Builder *builder, OperationState &result, "
"StringRef callee, ArrayRef<Value *> arguments", [{
buildGenericCallOp(builder, result, callee, arguments);
}]>
OpBuilder<"Builder *builder, OperationState &state, "
"StringRef callee, ArrayRef<Value *> arguments">
];
}
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
// Allow building a MulOp from the two input operands.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
buildMulOp(b, result, lhs, rhs);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
];
}
def PrintOp : Toy_Op<"print"> {
@ -238,7 +229,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
// Allow building a ReturnOp with no return operand.
let builders = [OpBuilder<
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
>];
// Provide extra utility definitions on the c++ operation class definition.
@ -258,12 +249,10 @@ def TransposeOp : Toy_Op<"transpose",
let results = (outs F64Tensor);
let hasCanonicalizer = 1;
// Allow building a TransposeOp with from the two input operands.
// Allow building a TransposeOp from the input operand.
let builders = [
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
buildTransposeOp(b, result, input);
}]
>];
OpBuilder<"Builder *b, OperationState &state, Value *input">
];
// Invoke a static verify method to verify this transpose operation.
let verifier = [{ return ::verify(*this); }];

View File

@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
// Toy Operations
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation.
/// The builder is passed as an argument, as is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
double value) {
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
double value) {
auto dataType = RankedTensorType::get({}, builder->getF64Type());
auto dataAttribute = DenseElementsAttr::get(dataType, value);
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
auto resultType =
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
if (!resultType)
return success();
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
return mlir::success();
}
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// AddOp
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
//===----------------------------------------------------------------------===//
// CastOp
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
//===----------------------------------------------------------------------===//
// GenericCallOp
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
StringRef callee, ArrayRef<mlir::Value *> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(arguments);
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
//===----------------------------------------------------------------------===//
// MulOp
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands({lhs, rhs});
}
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
/// interface.
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
<< results.front() << ")";
}
static void buildTransposeOp(mlir::Builder *builder,
mlir::OperationState &state, mlir::Value *value) {
//===----------------------------------------------------------------------===//
// TransposeOp
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value) {
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
state.addOperands(value);
}

View File

@ -439,7 +439,8 @@ methods. ODS can generate some simple build methods automatically, and in this
case it will generate our first build method for us. For the rest, we define the
[`builders`](../../OpDefinitions.md#custom-builder-methods) field. This field
takes a list of `OpBuilder` objects that take a string corresponding to a list
of c++ parameters, as well as a code block.
of c++ parameters, as well as an optional code block that can be used to specify
the implementation inline.
```tablegen
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
@ -476,15 +477,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
// Build a constant with a given constant tensor value.
OpBuilder<"Builder *builder, OperationState &result, "
"DenseElementsAttr value", [{
// Call into an autogenerated `build` method.
build(builder, result, value.getType(), value);
}]>,
// Build a constant with a given constant floating-point value. This builder
// invokes a static `buildConstantOp` utility function in a c++ source file
// to keep the tablegen c++ code blocks simple.
OpBuilder<"Builder *builder, OperationState &result, double value", [{
buildConstantOp(builder, result, value);
}]>
// creates a declaration for `ConstantOp::build` with the given parameters.
OpBuilder<"Builder *builder, OperationState &result, double value">
];
}
```