[mlir:Toy][NFC] Switch toy to use prefixed accessors

This commit is contained in:
River Riddle 2022-03-15 16:09:19 -07:00
parent 8ce3750ff6
commit ccfcfa9423
18 changed files with 63 additions and 91 deletions

View File

@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
// Base class for toy dialect operations. This operation inherits from the base
@ -143,21 +144,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -125,7 +125,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verifier for the constant operation. This corresponds to the
@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = value().getType().cast<mlir::TensorType>();
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")

View File

@ -22,6 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
// Base class for toy dialect operations. This operation inherits from the base
@ -142,21 +143,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -125,7 +125,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verifier for the constant operation. This corresponds to the
@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = value().getType().cast<mlir::TensorType>();
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")

View File

@ -25,6 +25,7 @@ include "toy/ShapeInferenceInterface.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
// Base class for toy dialect operations. This operation inherits from the base
@ -172,21 +173,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -187,7 +187,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verifier for the constant operation. This corresponds to the
@ -201,7 +201,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = value().getType().cast<mlir::TensorType>();
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -327,7 +327,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
//===----------------------------------------------------------------------===//
// MulOp

View File

@ -25,6 +25,7 @@ include "toy/ShapeInferenceInterface.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
// Base class for toy dialect operations. This operation inherits from the base
@ -172,21 +173,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -187,7 +187,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verifier for the constant operation. This corresponds to the
@ -201,7 +201,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = value().getType().cast<mlir::TensorType>();
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -327,7 +327,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
//===----------------------------------------------------------------------===//
// MulOp

View File

@ -115,10 +115,10 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
auto loadedLhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
auto loadedRhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
auto loadedLhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getLhs(), loopIvs);
auto loadedRhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getRhs(), loopIvs);
// Create the binary operation performed on the loaded values.
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
@ -138,7 +138,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
// When lowering the constant operation, we allocate and assign the constant
@ -286,7 +286,7 @@ struct TransposeOpLowering : public ConversionPattern {
// TransposeOp. This allows for using the nice named
// accessors that are generated by the ODS.
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
Value input = transposeAdaptor.getInput();
// Transpose the elements by generating a load from the
// reverse indices.

View File

@ -25,6 +25,7 @@ include "toy/ShapeInferenceInterface.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
// Base class for toy dialect operations. This operation inherits from the base
@ -172,21 +173,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -187,7 +187,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verifier for the constant operation. This corresponds to the
@ -201,7 +201,7 @@ mlir::LogicalResult ConstantOp::verify() {
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = value().getType().cast<mlir::TensorType>();
auto attrType = getValue().getType().cast<mlir::TensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -327,7 +327,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
//===----------------------------------------------------------------------===//
// MulOp

View File

@ -115,10 +115,10 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
auto loadedLhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
auto loadedRhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
auto loadedLhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getLhs(), loopIvs);
auto loadedRhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getRhs(), loopIvs);
// Create the binary operation performed on the loaded values.
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
@ -138,7 +138,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
// When lowering the constant operation, we allocate and assign the constant
@ -286,7 +286,7 @@ struct TransposeOpLowering : public ConversionPattern {
// TransposeOp. This allows for using the nice named
// accessors that are generated by the ODS.
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
Value input = transposeAdaptor.getInput();
// Transpose the elements by generating a load from the
// reverse indices.

View File

@ -101,7 +101,7 @@ public:
// Generate a call to printf for the current element of the loop.
auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
rewriter.create<func::CallOp>(
loc, printfRef, rewriter.getIntegerType(32),
ArrayRef<Value>({formatSpecifierCst, elementLoad}));

View File

@ -25,6 +25,7 @@ include "toy/ShapeInferenceInterface.td"
def Toy_Dialect : Dialect {
let name = "toy";
let cppNamespace = "::mlir::toy";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
// We set this bit to generate a declaration of the `materializeConstant`
// method so that we can materialize constants for our toy operations.
@ -191,21 +192,15 @@ def FuncOp : Toy_Op<"func", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)
>];
let extraClassDeclaration = [{
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return type().getInputs(); }
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return type().getResults(); }
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
}];
let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;

View File

@ -174,7 +174,7 @@ mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser,
void ConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << value();
printer << getValue();
}
/// Verify that the given attribute value is valid for the given type.
@ -236,16 +236,16 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
mlir::LogicalResult ConstantOp::verify() {
return verifyConstantForType(getResult().getType(), value(), *this);
return verifyConstantForType(getResult().getType(), getValue(), *this);
}
mlir::LogicalResult StructConstantOp::verify() {
return verifyConstantForType(getResult().getType(), value(), *this);
return verifyConstantForType(getResult().getType(), getValue(), *this);
}
/// Infer the output shape of the ConstantOp, this is required by the shape
/// inference interface.
void ConstantOp::inferShapes() { getResult().setType(value().getType()); }
void ConstantOp::inferShapes() { getResult().setType(getValue().getType()); }
//===----------------------------------------------------------------------===//
// AddOp
@ -354,7 +354,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
//===----------------------------------------------------------------------===//
// MulOp
@ -430,8 +430,8 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
}
mlir::LogicalResult StructAccessOp::verify() {
StructType structTy = input().getType().cast<StructType>();
size_t indexValue = index();
StructType structTy = getInput().getType().cast<StructType>();
size_t indexValue = getIndex();
if (indexValue >= structTy.getNumElementTypes())
return emitOpError()
<< "index should be within the range of the input struct type";

View File

@ -115,10 +115,10 @@ struct BinaryOpLowering : public ConversionPattern {
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
auto loadedLhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
auto loadedRhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
auto loadedLhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getLhs(), loopIvs);
auto loadedRhs = builder.create<AffineLoadOp>(
loc, binaryAdaptor.getRhs(), loopIvs);
// Create the binary operation performed on the loaded values.
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
@ -138,7 +138,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
// When lowering the constant operation, we allocate and assign the constant
@ -286,7 +286,7 @@ struct TransposeOpLowering : public ConversionPattern {
// TransposeOp. This allows for using the nice named
// accessors that are generated by the ODS.
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
Value input = transposeAdaptor.getInput();
// Transpose the elements by generating a load from the
// reverse indices.

View File

@ -101,7 +101,7 @@ public:
// Generate a call to printf for the current element of the loop.
auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
rewriter.create<memref::LoadOp>(loc, printOp.input(), loopIvs);
rewriter.create<memref::LoadOp>(loc, printOp.getInput(), loopIvs);
rewriter.create<func::CallOp>(
loc, printfRef, rewriter.getIntegerType(32),
ArrayRef<Value>({formatSpecifierCst, elementLoad}));

View File

@ -24,11 +24,13 @@ namespace {
} // namespace
/// Fold constants.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValue();
}
/// Fold struct constants.
OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
return value();
return getValue();
}
/// Fold simple struct access operations that access into a constant.
@ -37,7 +39,7 @@ OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
if (!structAttr)
return nullptr;
size_t elementIndex = index();
size_t elementIndex = getIndex();
return structAttr[elementIndex];
}