forked from OSchip/llvm-project
NFC: Tidy up the implementation of operations in the Toy tutorial
Use header blocks to separate operation implementations, and switch the build methods to be out-of-line when possible. PiperOrigin-RevId: 278982913
This commit is contained in:
parent
22cfff7043
commit
2fddfcfb14
|
@ -75,15 +75,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// using `builder.create<ConstantOp>(...)`.
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, state, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value.
|
// Build a constant with a given constant floating-point value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
OpBuilder<"Builder *builder, OperationState &state, double value">
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Invoke a static verify method to verify this constant operation.
|
||||||
|
@ -102,10 +100,8 @@ def AddOp : Toy_Op<"add"> {
|
||||||
|
|
||||||
// Allow building an AddOp with from the two input operands.
|
// Allow building an AddOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildAddOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def GenericCallOp : Toy_Op<"generic_call"> {
|
def GenericCallOp : Toy_Op<"generic_call"> {
|
||||||
|
@ -134,11 +130,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
|
||||||
|
|
||||||
// Add custom build methods for the generic call operation.
|
// Add custom build methods for the generic call operation.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
"StringRef callee, ArrayRef<Value *> arguments">
|
||||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
|
||||||
buildGenericCallOp(builder, result, callee, arguments);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,10 +147,8 @@ def MulOp : Toy_Op<"mul"> {
|
||||||
|
|
||||||
// Allow building a MulOp with from the two input operands.
|
// Allow building a MulOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildMulOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def PrintOp : Toy_Op<"print"> {
|
def PrintOp : Toy_Op<"print"> {
|
||||||
|
@ -210,7 +201,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||||
|
|
||||||
// Allow building a ReturnOp with no return operand.
|
// Allow building a ReturnOp with no return operand.
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
|
||||||
>];
|
>];
|
||||||
|
|
||||||
// Provide extra utility definitions on the c++ operation class definition.
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
@ -228,12 +219,10 @@ def TransposeOp : Toy_Op<"transpose"> {
|
||||||
let arguments = (ins F64Tensor:$input);
|
let arguments = (ins F64Tensor:$input);
|
||||||
let results = (outs F64Tensor);
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
// Allow building a TransposeOp with from the two input operands.
|
// Allow building a TransposeOp with from the input operand.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *input">
|
||||||
buildTransposeOp(b, result, input);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
|
@ -45,14 +45,17 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||||
// Toy Operations
|
// Toy Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
|
||||||
/// Build a constant operation.
|
/// Build a constant operation.
|
||||||
/// The builder is passed as an argument, so is the state that this method is
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
/// expected to fill in order to build the operation.
|
/// expected to fill in order to build the operation.
|
||||||
static void buildConstantOp(mlir::Builder *builder,
|
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
mlir::OperationState &result, double value) {
|
double value) {
|
||||||
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
||||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
ConstantOp::build(builder, result, dataType, dataAttribute);
|
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||||
|
@ -60,7 +63,8 @@ static void buildConstantOp(mlir::Builder *builder,
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult()->getType().cast<mlir::RankedTensorType>();
|
auto resultType =
|
||||||
|
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
@ -86,27 +90,38 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &result,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// AddOp
|
||||||
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
|
||||||
result.addOperands({lhs, rhs});
|
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildGenericCallOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &result, StringRef callee,
|
// GenericCallOp
|
||||||
ArrayRef<mlir::Value *> arguments) {
|
|
||||||
|
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||||
// Generic call always returns an unranked Tensor initially.
|
// Generic call always returns an unranked Tensor initially.
|
||||||
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
result.addOperands(arguments);
|
state.addOperands(arguments);
|
||||||
result.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &result,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// MulOp
|
||||||
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
|
||||||
result.addOperands({lhs, rhs});
|
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
|
@ -142,10 +157,13 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
<< results.front() << ")";
|
<< results.front() << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildTransposeOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &result, mlir::Value *value) {
|
// TransposeOp
|
||||||
result.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
|
||||||
result.addOperands(value);
|
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *value) {
|
||||||
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
static mlir::LogicalResult verify(TransposeOp op) {
|
||||||
|
|
|
@ -75,15 +75,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// using `builder.create<ConstantOp>(...)`.
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, state, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value.
|
// Build a constant with a given constant floating-point value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
OpBuilder<"Builder *builder, OperationState &state, double value">
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Invoke a static verify method to verify this constant operation.
|
||||||
|
@ -102,10 +100,8 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
||||||
|
|
||||||
// Allow building an AddOp with from the two input operands.
|
// Allow building an AddOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildAddOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def GenericCallOp : Toy_Op<"generic_call"> {
|
def GenericCallOp : Toy_Op<"generic_call"> {
|
||||||
|
@ -134,11 +130,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
|
||||||
|
|
||||||
// Add custom build methods for the generic call operation.
|
// Add custom build methods for the generic call operation.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
"StringRef callee, ArrayRef<Value *> arguments">
|
||||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
|
||||||
buildGenericCallOp(builder, result, callee, arguments);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,10 +147,8 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {
|
||||||
|
|
||||||
// Allow building a MulOp with from the two input operands.
|
// Allow building a MulOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildMulOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def PrintOp : Toy_Op<"print"> {
|
def PrintOp : Toy_Op<"print"> {
|
||||||
|
@ -213,7 +204,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||||
|
|
||||||
// Allow building a ReturnOp with no return operand.
|
// Allow building a ReturnOp with no return operand.
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
|
||||||
>];
|
>];
|
||||||
|
|
||||||
// Provide extra utility definitions on the c++ operation class definition.
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
@ -234,12 +225,10 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
||||||
// Enabled registering canonicalization patterns with this operation.
|
// Enabled registering canonicalization patterns with this operation.
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
// Allow building a TransposeOp with from the two input operands.
|
// Allow building a TransposeOp with from the input operand.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *input">
|
||||||
buildTransposeOp(b, result, input);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
|
@ -45,11 +45,14 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||||
// Toy Operations
|
// Toy Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
|
||||||
/// Build a constant operation.
|
/// Build a constant operation.
|
||||||
/// The builder is passed as an argument, so is the state that this method is
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
/// expected to fill in order to build the operation.
|
/// expected to fill in order to build the operation.
|
||||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
double value) {
|
double value) {
|
||||||
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
||||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||||
|
@ -60,7 +63,8 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType =
|
||||||
|
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
@ -86,27 +90,38 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// AddOp
|
||||||
|
|
||||||
|
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildGenericCallOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, StringRef callee,
|
// GenericCallOp
|
||||||
ArrayRef<mlir::Value *> arguments) {
|
|
||||||
|
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||||
// Generic call always returns an unranked Tensor initially.
|
// Generic call always returns an unranked Tensor initially.
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(arguments);
|
state.addOperands(arguments);
|
||||||
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// MulOp
|
||||||
|
|
||||||
|
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
|
@ -142,8 +157,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
<< results.front() << ")";
|
<< results.front() << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildTransposeOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, mlir::Value *value) {
|
// TransposeOp
|
||||||
|
|
||||||
|
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *value) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// using `builder.create<ConstantOp>(...)`.
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, state, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value.
|
// Build a constant with a given constant floating-point value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
OpBuilder<"Builder *builder, OperationState &state, double value">
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Invoke a static verify method to verify this constant operation.
|
||||||
|
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
|
||||||
|
|
||||||
// Allow building an AddOp with from the two input operands.
|
// Allow building an AddOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildAddOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def CastOp : Toy_Op<"cast",
|
def CastOp : Toy_Op<"cast",
|
||||||
|
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
|
||||||
|
|
||||||
// Add custom build methods for the generic call operation.
|
// Add custom build methods for the generic call operation.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
"StringRef callee, ArrayRef<Value *> arguments">
|
||||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
|
||||||
buildGenericCallOp(builder, result, callee, arguments);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
|
||||||
|
|
||||||
// Allow building a MulOp with from the two input operands.
|
// Allow building a MulOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildMulOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def PrintOp : Toy_Op<"print"> {
|
def PrintOp : Toy_Op<"print"> {
|
||||||
|
@ -237,7 +228,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||||
|
|
||||||
// Allow building a ReturnOp with no return operand.
|
// Allow building a ReturnOp with no return operand.
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
|
||||||
>];
|
>];
|
||||||
|
|
||||||
// Provide extra utility definitions on the c++ operation class definition.
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
@ -257,12 +248,10 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
let results = (outs F64Tensor);
|
let results = (outs F64Tensor);
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
// Allow building a TransposeOp with from the two input operands.
|
// Allow building a TransposeOp with from the input operand.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *input">
|
||||||
buildTransposeOp(b, result, input);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
|
@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||||
// Toy Operations
|
// Toy Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
|
||||||
/// Build a constant operation.
|
/// Build a constant operation.
|
||||||
/// The builder is passed as an argument, so is the state that this method is
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
/// expected to fill in order to build the operation.
|
/// expected to fill in order to build the operation.
|
||||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
double value) {
|
double value) {
|
||||||
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
||||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer the output shape of the CastOp, this is required by the shape
|
|
||||||
/// inference interface.
|
|
||||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||||
/// in the op definition.
|
/// in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType =
|
||||||
|
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// AddOp
|
||||||
|
|
||||||
|
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
static void buildGenericCallOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, StringRef callee,
|
// CastOp
|
||||||
ArrayRef<mlir::Value *> arguments) {
|
|
||||||
|
/// Infer the output shape of the CastOp, this is required by the shape
|
||||||
|
/// inference interface.
|
||||||
|
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericCallOp
|
||||||
|
|
||||||
|
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||||
// Generic call always returns an unranked Tensor initially.
|
// Generic call always returns an unranked Tensor initially.
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(arguments);
|
state.addOperands(arguments);
|
||||||
|
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||||
/// call interface.
|
/// call interface.
|
||||||
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||||
|
|
||||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// MulOp
|
||||||
|
|
||||||
|
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
|
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
<< results.front() << ")";
|
<< results.front() << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildTransposeOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, mlir::Value *value) {
|
// TransposeOp
|
||||||
|
|
||||||
|
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *value) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// using `builder.create<ConstantOp>(...)`.
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, state, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value.
|
// Build a constant with a given constant floating-point value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
OpBuilder<"Builder *builder, OperationState &state, double value">
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Invoke a static verify method to verify this constant operation.
|
||||||
|
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
|
||||||
|
|
||||||
// Allow building an AddOp with from the two input operands.
|
// Allow building an AddOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildAddOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def CastOp : Toy_Op<"cast",
|
def CastOp : Toy_Op<"cast",
|
||||||
|
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
|
||||||
|
|
||||||
// Add custom build methods for the generic call operation.
|
// Add custom build methods for the generic call operation.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
"StringRef callee, ArrayRef<Value *> arguments">
|
||||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
|
||||||
buildGenericCallOp(builder, result, callee, arguments);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
|
||||||
|
|
||||||
// Allow building a MulOp with from the two input operands.
|
// Allow building a MulOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildMulOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def PrintOp : Toy_Op<"print"> {
|
def PrintOp : Toy_Op<"print"> {
|
||||||
|
@ -238,7 +229,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||||
|
|
||||||
// Allow building a ReturnOp with no return operand.
|
// Allow building a ReturnOp with no return operand.
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
|
||||||
>];
|
>];
|
||||||
|
|
||||||
// Provide extra utility definitions on the c++ operation class definition.
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
@ -258,12 +249,10 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
let results = (outs F64Tensor);
|
let results = (outs F64Tensor);
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
// Allow building a TransposeOp with from the two input operands.
|
// Allow building a TransposeOp with from the input operand.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *input">
|
||||||
buildTransposeOp(b, result, input);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
|
@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||||
// Toy Operations
|
// Toy Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
|
||||||
/// Build a constant operation.
|
/// Build a constant operation.
|
||||||
/// The builder is passed as an argument, so is the state that this method is
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
/// expected to fill in order to build the operation.
|
/// expected to fill in order to build the operation.
|
||||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
double value) {
|
double value) {
|
||||||
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
||||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer the output shape of the CastOp, this is required by the shape
|
|
||||||
/// inference interface.
|
|
||||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||||
/// in the op definition.
|
/// in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType =
|
||||||
|
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// AddOp
|
||||||
|
|
||||||
|
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
static void buildGenericCallOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, StringRef callee,
|
// CastOp
|
||||||
ArrayRef<mlir::Value *> arguments) {
|
|
||||||
|
/// Infer the output shape of the CastOp, this is required by the shape
|
||||||
|
/// inference interface.
|
||||||
|
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericCallOp
|
||||||
|
|
||||||
|
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||||
// Generic call always returns an unranked Tensor initially.
|
// Generic call always returns an unranked Tensor initially.
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(arguments);
|
state.addOperands(arguments);
|
||||||
|
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||||
/// call interface.
|
/// call interface.
|
||||||
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||||
|
|
||||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// MulOp
|
||||||
|
|
||||||
|
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
|
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
<< results.front() << ")";
|
<< results.front() << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildTransposeOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, mlir::Value *value) {
|
// TransposeOp
|
||||||
|
|
||||||
|
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *value) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,15 +79,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// using `builder.create<ConstantOp>(...)`.
|
// using `builder.create<ConstantOp>(...)`.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, state, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value.
|
// Build a constant with a given constant floating-point value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
OpBuilder<"Builder *builder, OperationState &state, double value">
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Invoke a static verify method to verify this constant operation.
|
||||||
|
@ -107,10 +105,8 @@ def AddOp : Toy_Op<"add",
|
||||||
|
|
||||||
// Allow building an AddOp with from the two input operands.
|
// Allow building an AddOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildAddOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def CastOp : Toy_Op<"cast",
|
def CastOp : Toy_Op<"cast",
|
||||||
|
@ -159,11 +155,8 @@ def GenericCallOp : Toy_Op<"generic_call",
|
||||||
|
|
||||||
// Add custom build methods for the generic call operation.
|
// Add custom build methods for the generic call operation.
|
||||||
let builders = [
|
let builders = [
|
||||||
// Build a constant with a given constant tensor value.
|
OpBuilder<"Builder *builder, OperationState &state, "
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
"StringRef callee, ArrayRef<Value *> arguments">
|
||||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
|
||||||
buildGenericCallOp(builder, result, callee, arguments);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,10 +173,8 @@ def MulOp : Toy_Op<"mul",
|
||||||
|
|
||||||
// Allow building a MulOp with from the two input operands.
|
// Allow building a MulOp with from the two input operands.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *lhs, Value *rhs">
|
||||||
buildMulOp(b, result, lhs, rhs);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def PrintOp : Toy_Op<"print"> {
|
def PrintOp : Toy_Op<"print"> {
|
||||||
|
@ -238,7 +229,7 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||||
|
|
||||||
// Allow building a ReturnOp with no return operand.
|
// Allow building a ReturnOp with no return operand.
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
"Builder *b, OperationState &state", [{ build(b, state, llvm::None); }]
|
||||||
>];
|
>];
|
||||||
|
|
||||||
// Provide extra utility definitions on the c++ operation class definition.
|
// Provide extra utility definitions on the c++ operation class definition.
|
||||||
|
@ -258,12 +249,10 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
let results = (outs F64Tensor);
|
let results = (outs F64Tensor);
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
// Allow building a TransposeOp with from the two input operands.
|
// Allow building a TransposeOp with from the input operand.
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
OpBuilder<"Builder *b, OperationState &state, Value *input">
|
||||||
buildTransposeOp(b, result, input);
|
];
|
||||||
}]
|
|
||||||
>];
|
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
|
|
|
@ -95,26 +95,26 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||||
// Toy Operations
|
// Toy Operations
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstantOp
|
||||||
|
|
||||||
/// Build a constant operation.
|
/// Build a constant operation.
|
||||||
/// The builder is passed as an argument, so is the state that this method is
|
/// The builder is passed as an argument, so is the state that this method is
|
||||||
/// expected to fill in order to build the operation.
|
/// expected to fill in order to build the operation.
|
||||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
double value) {
|
double value) {
|
||||||
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
auto dataType = RankedTensorType::get({}, builder->getF64Type());
|
||||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer the output shape of the CastOp, this is required by the shape
|
|
||||||
/// inference interface.
|
|
||||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||||
/// in the op definition.
|
/// in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType =
|
||||||
|
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
@ -140,8 +140,11 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// AddOp
|
||||||
|
|
||||||
|
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -150,9 +153,18 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
static void buildGenericCallOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, StringRef callee,
|
// CastOp
|
||||||
ArrayRef<mlir::Value *> arguments) {
|
|
||||||
|
/// Infer the output shape of the CastOp, this is required by the shape
|
||||||
|
/// inference interface.
|
||||||
|
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericCallOp
|
||||||
|
|
||||||
|
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||||
// Generic call always returns an unranked Tensor initially.
|
// Generic call always returns an unranked Tensor initially.
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(arguments);
|
state.addOperands(arguments);
|
||||||
|
@ -169,8 +181,11 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||||
/// call interface.
|
/// call interface.
|
||||||
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||||
|
|
||||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::Value *lhs, mlir::Value *rhs) {
|
// MulOp
|
||||||
|
|
||||||
|
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *lhs, mlir::Value *rhs) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands({lhs, rhs});
|
state.addOperands({lhs, rhs});
|
||||||
}
|
}
|
||||||
|
@ -179,6 +194,9 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
/// interface.
|
/// interface.
|
||||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
|
@ -214,8 +232,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
<< results.front() << ")";
|
<< results.front() << ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildTransposeOp(mlir::Builder *builder,
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::OperationState &state, mlir::Value *value) {
|
// TransposeOp
|
||||||
|
|
||||||
|
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||||
|
mlir::Value *value) {
|
||||||
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
state.addTypes(UnrankedTensorType::get(builder->getF64Type()));
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -439,7 +439,8 @@ methods. ODS can generate some simple build methods automatically, and in this
|
||||||
case it will generate our first build method for us. For the rest, we define the
|
case it will generate our first build method for us. For the rest, we define the
|
||||||
[`builders`](../../OpDefinitions.md#custom-builder-methods) field. This field
|
[`builders`](../../OpDefinitions.md#custom-builder-methods) field. This field
|
||||||
takes a list of `OpBuilder` objects that take a string corresponding to a list
|
takes a list of `OpBuilder` objects that take a string corresponding to a list
|
||||||
of c++ parameters, as well as a code block.
|
of c++ parameters, as well as an optional code block that can be used to specify
|
||||||
|
the implementation inline.
|
||||||
|
|
||||||
```tablegen
|
```tablegen
|
||||||
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
|
@ -476,15 +477,13 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
// Build a constant with a given constant tensor value.
|
// Build a constant with a given constant tensor value.
|
||||||
OpBuilder<"Builder *builder, OperationState &result, "
|
OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
"DenseElementsAttr value", [{
|
"DenseElementsAttr value", [{
|
||||||
|
// Call into an autogenerated `build` method.
|
||||||
build(builder, result, value.getType(), value);
|
build(builder, result, value.getType(), value);
|
||||||
}]>,
|
}]>,
|
||||||
|
|
||||||
// Build a constant with a given constant floating-point value. This builder
|
// Build a constant with a given constant floating-point value. This builder
|
||||||
// invokes a static `buildConstantOp` utility function in a c++ source file
|
// creates a declaration for `ConstantOp::build` with the given parameters.
|
||||||
// to keep the tablegen c++ code blocks simple.
|
OpBuilder<"Builder *builder, OperationState &result, double value">
|
||||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
|
||||||
buildConstantOp(builder, result, value);
|
|
||||||
}]>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
Loading…
Reference in New Issue