forked from OSchip/llvm-project
[mlir][NFC] Update Toy operations to use `hasVerifier` instead of `verifier`
The verifier field is deprecated, and slated for removal. Differential Revision: https://reviews.llvm.org/D118816
This commit is contained in:
parent
42e5f1d97b
commit
4e190c58de
|
@ -280,7 +280,7 @@ class ConstantOp : public mlir::Op<
|
||||||
/// traits provide. Here we will ensure that the specific invariants of the
|
/// traits provide. Here we will ensure that the specific invariants of the
|
||||||
/// constant operation are upheld, for example the result type must be
|
/// constant operation are upheld, for example the result type must be
|
||||||
/// of TensorType and matches the type of the constant `value`.
|
/// of TensorType and matches the type of the constant `value`.
|
||||||
LogicalResult verify();
|
LogicalResult verifyInvariants();
|
||||||
|
|
||||||
/// Provide an interface to build this operation from a set of input values.
|
/// Provide an interface to build this operation from a set of input values.
|
||||||
/// This interface is used by the `builder` classes to allow for easily
|
/// This interface is used by the `builder` classes to allow for easily
|
||||||
|
@ -495,11 +495,12 @@ def ConstantOp : Toy_Op<"constant"> {
|
||||||
// F64Tensor corresponds to a 64-bit floating-point TensorType.
|
// F64Tensor corresponds to a 64-bit floating-point TensorType.
|
||||||
let results = (outs F64Tensor);
|
let results = (outs F64Tensor);
|
||||||
|
|
||||||
// Add additional verification logic to the constant operation. Here we invoke
|
// Add additional verification logic to the constant operation. Setting this bit
|
||||||
// a static `verify` method in a C++ source file. This codeblock is executed
|
// to `1` will generate a `::mlir::LogicalResult verify()` declaration on the
|
||||||
// inside of ConstantOp::verify, so we can use `this` to refer to the current
|
// operation class that is called after ODS constructs have been verified, for
|
||||||
// operation instance.
|
// example the types of arguments and results. We implement additional verification
|
||||||
let verifier = [{ return ::verify(*this); }];
|
// in the definition of this `verify` method in the C++ source file.
|
||||||
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -76,8 +76,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def AddOp : Toy_Op<"add"> {
|
def AddOp : Toy_Op<"add"> {
|
||||||
|
@ -224,7 +224,7 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Invoke a static verify method to verify this return operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TransposeOp : Toy_Op<"transpose"> {
|
def TransposeOp : Toy_Op<"transpose"> {
|
||||||
|
@ -243,7 +243,7 @@ def TransposeOp : Toy_Op<"transpose"> {
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Invoke a static verify method to verify this transpose operation.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -126,21 +126,20 @@ static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
|
||||||
printer << op.value();
|
printer << op.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
/// in the op definition.
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
auto attrType = value().getType().cast<mlir::TensorType>();
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op.emitOpError(
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"return type must match the one of the attached value "
|
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
<< attrType.getRank() << " != " << resultType.getRank();
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
}
|
}
|
||||||
|
@ -148,7 +147,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// Check that each of the dimensions match between the two types.
|
// Check that each of the dimensions match between the two types.
|
||||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||||
return op.emitOpError(
|
return emitOpError(
|
||||||
"return type shape mismatches its attribute at dimension ")
|
"return type shape mismatches its attribute at dimension ")
|
||||||
<< dim << ": " << attrType.getShape()[dim]
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
<< " != " << resultType.getShape()[dim];
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
@ -190,28 +189,27 @@ void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -219,9 +217,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -233,16 +231,16 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
|
@ -75,8 +75,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
||||||
|
@ -225,8 +225,8 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
bool hasOperand() { return getNumOperands() != 0; }
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
||||||
|
@ -247,8 +247,8 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "Value":$input)>
|
OpBuilder<(ins "Value":$input)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -126,21 +126,20 @@ static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
|
||||||
printer << op.value();
|
printer << op.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
/// in the op definition.
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
auto attrType = value().getType().cast<mlir::TensorType>();
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op.emitOpError(
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"return type must match the one of the attached value "
|
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
<< attrType.getRank() << " != " << resultType.getRank();
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
}
|
}
|
||||||
|
@ -148,7 +147,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// Check that each of the dimensions match between the two types.
|
// Check that each of the dimensions match between the two types.
|
||||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||||
return op.emitOpError(
|
return emitOpError(
|
||||||
"return type shape mismatches its attribute at dimension ")
|
"return type shape mismatches its attribute at dimension ")
|
||||||
<< dim << ": " << attrType.getShape()[dim]
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
<< " != " << resultType.getShape()[dim];
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
@ -190,28 +189,27 @@ void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -219,9 +217,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -233,16 +231,16 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||||
state.addOperands(value);
|
state.addOperands(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
|
@ -78,8 +78,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def AddOp : Toy_Op<"add",
|
def AddOp : Toy_Op<"add",
|
||||||
|
@ -252,8 +252,8 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
bool hasOperand() { return getNumOperands() != 0; }
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TransposeOp : Toy_Op<"transpose",
|
def TransposeOp : Toy_Op<"transpose",
|
||||||
|
@ -275,8 +275,8 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
OpBuilder<(ins "Value":$input)>
|
OpBuilder<(ins "Value":$input)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -182,21 +182,20 @@ static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
|
||||||
printer << op.value();
|
printer << op.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
/// in the op definition.
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
auto attrType = value().getType().cast<mlir::TensorType>();
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op.emitOpError(
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"return type must match the one of the attached value "
|
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
<< attrType.getRank() << " != " << resultType.getRank();
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
}
|
}
|
||||||
|
@ -204,7 +203,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// Check that each of the dimensions match between the two types.
|
// Check that each of the dimensions match between the two types.
|
||||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||||
return op.emitOpError(
|
return emitOpError(
|
||||||
"return type shape mismatches its attribute at dimension ")
|
"return type shape mismatches its attribute at dimension ")
|
||||||
<< dim << ": " << attrType.getShape()[dim]
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
<< " != " << resultType.getShape()[dim];
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
@ -286,28 +285,27 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -315,9 +313,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -335,16 +333,16 @@ void TransposeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
|
@ -78,8 +78,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def AddOp : Toy_Op<"add",
|
def AddOp : Toy_Op<"add",
|
||||||
|
@ -253,8 +253,8 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
bool hasOperand() { return getNumOperands() != 0; }
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TransposeOp : Toy_Op<"transpose",
|
def TransposeOp : Toy_Op<"transpose",
|
||||||
|
@ -276,8 +276,8 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
OpBuilder<(ins "Value":$input)>
|
OpBuilder<(ins "Value":$input)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -182,21 +182,20 @@ static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
|
||||||
printer << op.value();
|
printer << op.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
/// in the op definition.
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
auto attrType = value().getType().cast<mlir::TensorType>();
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op.emitOpError(
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"return type must match the one of the attached value "
|
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
<< attrType.getRank() << " != " << resultType.getRank();
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
}
|
}
|
||||||
|
@ -204,7 +203,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// Check that each of the dimensions match between the two types.
|
// Check that each of the dimensions match between the two types.
|
||||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||||
return op.emitOpError(
|
return emitOpError(
|
||||||
"return type shape mismatches its attribute at dimension ")
|
"return type shape mismatches its attribute at dimension ")
|
||||||
<< dim << ": " << attrType.getShape()[dim]
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
<< " != " << resultType.getShape()[dim];
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
@ -286,28 +285,27 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -315,9 +313,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -335,16 +333,16 @@ void TransposeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
|
@ -78,8 +78,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def AddOp : Toy_Op<"add",
|
def AddOp : Toy_Op<"add",
|
||||||
|
@ -253,8 +253,8 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
bool hasOperand() { return getNumOperands() != 0; }
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TransposeOp : Toy_Op<"transpose",
|
def TransposeOp : Toy_Op<"transpose",
|
||||||
|
@ -276,8 +276,8 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
OpBuilder<(ins "Value":$input)>
|
OpBuilder<(ins "Value":$input)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -182,21 +182,20 @@ static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
|
||||||
printer << op.value();
|
printer << op.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the
|
||||||
/// in the op definition.
|
/// `let hasVerifier = 1` in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
// If the return type of the constant is not an unranked tensor, the shape
|
// If the return type of the constant is not an unranked tensor, the shape
|
||||||
// must match the shape of the attribute holding the data.
|
// must match the shape of the attribute holding the data.
|
||||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Check that the rank of the attribute type matches the rank of the constant
|
// Check that the rank of the attribute type matches the rank of the constant
|
||||||
// result type.
|
// result type.
|
||||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
auto attrType = value().getType().cast<mlir::TensorType>();
|
||||||
if (attrType.getRank() != resultType.getRank()) {
|
if (attrType.getRank() != resultType.getRank()) {
|
||||||
return op.emitOpError(
|
return emitOpError("return type must match the one of the attached value "
|
||||||
"return type must match the one of the attached value "
|
|
||||||
"attribute: ")
|
"attribute: ")
|
||||||
<< attrType.getRank() << " != " << resultType.getRank();
|
<< attrType.getRank() << " != " << resultType.getRank();
|
||||||
}
|
}
|
||||||
|
@ -204,7 +203,7 @@ static mlir::LogicalResult verify(ConstantOp op) {
|
||||||
// Check that each of the dimensions match between the two types.
|
// Check that each of the dimensions match between the two types.
|
||||||
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
|
||||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||||
return op.emitOpError(
|
return emitOpError(
|
||||||
"return type shape mismatches its attribute at dimension ")
|
"return type shape mismatches its attribute at dimension ")
|
||||||
<< dim << ": " << attrType.getShape()[dim]
|
<< dim << ": " << attrType.getShape()[dim]
|
||||||
<< " != " << resultType.getShape()[dim];
|
<< " != " << resultType.getShape()[dim];
|
||||||
|
@ -286,28 +285,27 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -315,9 +313,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -335,16 +333,16 @@ void TransposeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
|
@ -94,8 +94,8 @@ def ConstantOp : Toy_Op<"constant",
|
||||||
OpBuilder<(ins "double":$value)>
|
OpBuilder<(ins "double":$value)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this constant operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
|
|
||||||
// Set the folder bit so that we can implement constant folders.
|
// Set the folder bit so that we can implement constant folders.
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -273,8 +273,8 @@ def ReturnOp : Toy_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
|
||||||
bool hasOperand() { return getNumOperands() != 0; }
|
bool hasOperand() { return getNumOperands() != 0; }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this return operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
|
def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
|
||||||
|
@ -295,7 +295,8 @@ def StructAccessOp : Toy_Op<"struct_access", [NoSideEffect]> {
|
||||||
OpBuilder<(ins "Value":$input, "size_t":$index)>
|
OpBuilder<(ins "Value":$input, "size_t":$index)>
|
||||||
];
|
];
|
||||||
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
// Indicate that additional verification for this operation is necessary.
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
// Set the folder bit so that we can fold constant accesses.
|
// Set the folder bit so that we can fold constant accesses.
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -320,7 +321,8 @@ def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, NoSideEffect]> {
|
||||||
|
|
||||||
let assemblyFormat = "$value attr-dict `:` type($output)";
|
let assemblyFormat = "$value attr-dict `:` type($output)";
|
||||||
|
|
||||||
let verifier = [{ return ::verify(*this); }];
|
// Indicate that additional verification for this operation is necessary.
|
||||||
|
let hasVerifier = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,8 +345,8 @@ def TransposeOp : Toy_Op<"transpose",
|
||||||
OpBuilder<(ins "Value":$input)>
|
OpBuilder<(ins "Value":$input)>
|
||||||
];
|
];
|
||||||
|
|
||||||
// Invoke a static verify method to verify this transpose operation.
|
// Indicate that additional verification for this operation is necessary.
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // TOY_OPS
|
#endif // TOY_OPS
|
||||||
|
|
|
@ -227,12 +227,12 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
||||||
|
|
||||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||||
/// in the op definition.
|
/// in the op definition.
|
||||||
static mlir::LogicalResult verify(ConstantOp op) {
|
mlir::LogicalResult ConstantOp::verify() {
|
||||||
return verifyConstantForType(op.getResult().getType(), op.value(), op);
|
return verifyConstantForType(getResult().getType(), value(), *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(StructConstantOp op) {
|
mlir::LogicalResult StructConstantOp::verify() {
|
||||||
return verifyConstantForType(op.getResult().getType(), op.value(), op);
|
return verifyConstantForType(getResult().getType(), value(), *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Infer the output shape of the ConstantOp, this is required by the shape
|
/// Infer the output shape of the ConstantOp, this is required by the shape
|
||||||
|
@ -312,28 +312,27 @@ void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReturnOp
|
// ReturnOp
|
||||||
|
|
||||||
static mlir::LogicalResult verify(ReturnOp op) {
|
mlir::LogicalResult ReturnOp::verify() {
|
||||||
// We know that the parent operation is a function, because of the 'HasParent'
|
// We know that the parent operation is a function, because of the 'HasParent'
|
||||||
// trait attached to the operation definition.
|
// trait attached to the operation definition.
|
||||||
auto function = cast<FuncOp>(op->getParentOp());
|
auto function = cast<FuncOp>((*this)->getParentOp());
|
||||||
|
|
||||||
/// ReturnOps can only have a single optional operand.
|
/// ReturnOps can only have a single optional operand.
|
||||||
if (op.getNumOperands() > 1)
|
if (getNumOperands() > 1)
|
||||||
return op.emitOpError() << "expects at most 1 return operand";
|
return emitOpError() << "expects at most 1 return operand";
|
||||||
|
|
||||||
// The operand number and types must match the function signature.
|
// The operand number and types must match the function signature.
|
||||||
const auto &results = function.getType().getResults();
|
const auto &results = function.getType().getResults();
|
||||||
if (op.getNumOperands() != results.size())
|
if (getNumOperands() != results.size())
|
||||||
return op.emitOpError()
|
return emitOpError() << "does not return the same number of values ("
|
||||||
<< "does not return the same number of values ("
|
<< getNumOperands() << ") as the enclosing function ("
|
||||||
<< op.getNumOperands() << ") as the enclosing function ("
|
|
||||||
<< results.size() << ")";
|
<< results.size() << ")";
|
||||||
|
|
||||||
// If the operation does not have an input, we are done.
|
// If the operation does not have an input, we are done.
|
||||||
if (!op.hasOperand())
|
if (!hasOperand())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputType = *op.operand_type_begin();
|
auto inputType = *operand_type_begin();
|
||||||
auto resultType = results.front();
|
auto resultType = results.front();
|
||||||
|
|
||||||
// Check that the result type of the function matches the operand type.
|
// Check that the result type of the function matches the operand type.
|
||||||
|
@ -341,9 +340,9 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
||||||
resultType.isa<mlir::UnrankedTensorType>())
|
resultType.isa<mlir::UnrankedTensorType>())
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
return op.emitError() << "type of return operand (" << inputType
|
return emitError() << "type of return operand (" << inputType
|
||||||
<< ") doesn't match function result type ("
|
<< ") doesn't match function result type (" << resultType
|
||||||
<< resultType << ")";
|
<< ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -360,15 +359,15 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
||||||
build(b, state, resultType, input, b.getI64IntegerAttr(index));
|
build(b, state, resultType, input, b.getI64IntegerAttr(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(StructAccessOp op) {
|
mlir::LogicalResult StructAccessOp::verify() {
|
||||||
StructType structTy = op.input().getType().cast<StructType>();
|
StructType structTy = input().getType().cast<StructType>();
|
||||||
size_t index = op.index();
|
size_t indexValue = index();
|
||||||
if (index >= structTy.getNumElementTypes())
|
if (indexValue >= structTy.getNumElementTypes())
|
||||||
return op.emitOpError()
|
return emitOpError()
|
||||||
<< "index should be within the range of the input struct type";
|
<< "index should be within the range of the input struct type";
|
||||||
mlir::Type resultType = op.getResult().getType();
|
mlir::Type resultType = getResult().getType();
|
||||||
if (resultType != structTy.getElementTypes()[index])
|
if (resultType != structTy.getElementTypes()[indexValue])
|
||||||
return op.emitOpError() << "must have the same result type as the struct "
|
return emitOpError() << "must have the same result type as the struct "
|
||||||
"element referred to by the index";
|
"element referred to by the index";
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
@ -388,16 +387,16 @@ void TransposeOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::LogicalResult verify(TransposeOp op) {
|
mlir::LogicalResult TransposeOp::verify() {
|
||||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||||
if (!inputType || !resultType)
|
if (!inputType || !resultType)
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
if (!std::equal(inputShape.begin(), inputShape.end(),
|
if (!std::equal(inputShape.begin(), inputShape.end(),
|
||||||
resultType.getShape().rbegin())) {
|
resultType.getShape().rbegin())) {
|
||||||
return op.emitError()
|
return emitError()
|
||||||
<< "expected result shape to be a transpose of the input";
|
<< "expected result shape to be a transpose of the input";
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
|
|
Loading…
Reference in New Issue