From 8e123ca65f5f9286e59f2c79184d01673c87aa42 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 31 Jan 2022 13:53:22 -0800 Subject: [PATCH] [mlir:Standard] Remove support for creating a `unit` ConstantOp This is completely unused upstream, and does not really have well defined semantics on what this is supposed to do/how this fits into the ecosystem. Given that, as part of splitting up the standard dialect it's best to just remove this behavior, instead of try to awkwardly fit it somewhere upstream. Downstream users are encouraged to define their own operations that clearly can define the semantics of this. This also uncovered several lingering uses of ConstantOp that weren't updated to use arith::ConstantOp, and worked during conversions because the constant was removed/converted into something else before verification. See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for more discussion. Differential Revision: https://reviews.llvm.org/D118654 --- flang/lib/Optimizer/Builder/Character.cpp | 2 +- .../mlir/Dialect/StandardOps/IR/Ops.td | 21 +---- .../StandardToLLVM/StandardToLLVM.cpp | 34 +++---- .../Async/Transforms/AsyncParallelFor.cpp | 2 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +- .../Transforms/PolynomialApproximation.cpp | 2 +- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 93 ++++--------------- .../VectorMultiDimReductionTransforms.cpp | 2 +- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 2 +- mlir/test/Dialect/Standard/invalid.mlir | 2 +- mlir/test/IR/core-ops.mlir | 3 - mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +- 12 files changed, 40 insertions(+), 127 deletions(-) diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp index 87faa3b42c44..e4719133f3fa 100644 --- a/flang/lib/Optimizer/Builder/Character.cpp +++ b/flang/lib/Optimizer/Builder/Character.cpp @@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) { /// Unwrap integer constant from mlir::Value. static llvm::Optional getIntIfConstant(mlir::Value value) { if (auto *definingOp = value.getDefiningOp()) - if (auto cst = mlir::dyn_cast(definingOp)) + if (auto cst = mlir::dyn_cast(definingOp)) if (auto intAttr = cst.getValue().dyn_cast()) return intAttr.getInt(); return {}; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 9efe4ceb2147..2aca33eda3c4 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant", operation ::= ssa-id `=` `std.constant` attribute-value `:` type ``` - The `constant` operation produces an SSA value equal to some constant - specified by an attribute. This is the way that MLIR uses to form simple - integer and floating point constants, as well as more exotic things like - references to functions and tensor/vector constants. + The `constant` operation produces an SSA value from a symbol reference to a + `builtin.func` operation Example: ```mlir - // Complex constant - %1 = constant [1.0 : f32, 1.0 : f32] : complex - // Reference to function @myfn. %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> // Equivalent generic forms - %1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex} - : () -> complex %2 = "std.constant"() {value = @myfn} : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) ``` @@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant", ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). }]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins FlatSymbolRefAttr:$value); let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>, - OpBuilder<(ins "Attribute":$value, "Type":$type), - [{ build($_builder, $_state, type, value); }]>, - ]; + let assemblyFormat = "attr-dict $value `:` type(results)"; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index b2e18ab8196f..04c51422ed11 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -435,31 +435,19 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // If constant refers to a function, convert it to "addressof". - if (auto symbolRef = op.getValue().dyn_cast()) { - auto type = typeConverter->convertType(op.getResult().getType()); - if (!type || !LLVM::isCompatibleType(type)) - return rewriter.notifyMatchFailure(op, "failed to convert result type"); + auto type = typeConverter->convertType(op.getResult().getType()); + if (!type || !LLVM::isCompatibleType(type)) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); - auto newOp = rewriter.create(op.getLoc(), type, - symbolRef.getValue()); - for (const NamedAttribute &attr : op->getAttrs()) { - if (attr.getName().strref() == "value") - continue; - newOp->setAttr(attr.getName(), attr.getValue()); - } - rewriter.replaceOp(op, newOp->getResults()); - return success(); + auto newOp = + rewriter.create(op.getLoc(), type, op.getValue()); + for (const NamedAttribute &attr : op->getAttrs()) { + if (attr.getName().strref() == "value") + continue; + newOp->setAttr(attr.getName(), attr.getValue()); } - - // Calling into other scopes (non-flat reference) is not supported in LLVM. - if (op.getValue().isa()) - return rewriter.notifyMatchFailure( - op, "referring to a symbol outside of the current module"); - - return LLVM::detail::oneToOneRewrite( - op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - *getTypeConverter(), rewriter); + rewriter.replaceOp(op, newOp->getResults()); + return success(); } }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index e1c91fbbc1d9..74d6d42e2b9b 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction( return llvm::to_vector( llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { if (IntegerAttr attr = std::get<1>(tuple)) - return b.create(attr); + return b.create(attr); return std::get<0>(tuple); })); }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 19e3c0318f57..32fd370012c4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1576,7 +1576,7 @@ public: isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) : DenseElementsAttr::get(outputType, intOutputValues); - rewriter.replaceOpWithNewOp(genericOp, outputAttr); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 1237f0b47cf7..d47d6ead0273 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Stitch results together into one large vector. Type resultEltType = results[0].getType().cast().getElementType(); Type resultExpandedType = VectorType::get(expandedShape, resultEltType); - Value result = builder.create( + Value result = builder.create( resultExpandedType, builder.getZeroAttr(resultExpandedType)); for (int64_t i = 0; i < maxLinearIndex; ++i) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 65e72293ed3f..bf35625adb62 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value); - return builder.create(loc, type, value); + if (ConstantOp::isBuildableWith(value, type)) + return builder.create(loc, type, + value.cast()); + return nullptr; } //===----------------------------------------------------------------------===// @@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { // ConstantOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ConstantOp &op) { - p << " "; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); - - if (op->getAttrs().size() > 1) - p << ' '; - p << op.getValue(); - - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) - p << " : " << op.getType(); -} - -static ParseResult parseConstantOp(OpAsmParser &parser, - OperationState &result) { - Attribute valueAttr; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(valueAttr, "value", result.attributes)) - return failure(); - - // If the attribute is a symbol reference, then we expect a trailing type. - Type type; - if (!valueAttr.isa()) - type = valueAttr.getType(); - else if (parser.parseColonType(type)) - return failure(); - - // Add the attribute type to the list. - return parser.addTypeToList(type, result.types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. LogicalResult ConstantOp::verify() { - auto value = getValue(); - if (!value) - return emitOpError("requires a 'value' attribute"); - + StringRef fnName = getValue(); Type type = getType(); - if (!value.getType().isa() && type != value.getType()) - return emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - if (type.isa()) { - auto fnAttr = value.dyn_cast(); - if (!fnAttr) - return emitOpError("requires 'value' to be a function reference"); + // Try to find the referenced function. + auto fn = (*this)->getParentOfType().lookupSymbol(fnName); + if (!fn) + return emitOpError() << "reference to undefined function '" << fnName + << "'"; - // Try to find the referenced function. - auto fn = (*this)->getParentOfType().lookupSymbol( - fnAttr.getValue()); - if (!fn) - return emitOpError() << "reference to undefined function '" - << fnAttr.getValue() << "'"; + // Check that the referenced function has the correct type. + if (fn.getType() != type) + return emitOpError("reference to function with mismatched type"); - // Check that the referenced function has the correct type. - if (fn.getType() != type) - return emitOpError("reference to function with mismatched type"); - - return success(); - } - - if (type.isa() && value.isa()) - return success(); - - return emitOpError("unsupported 'value' attribute: ") << value; + return success(); } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); - return getValue(); + return getValueAttr(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { - Type type = getType(); - if (type.isa()) { - setNameFn(getResult(), "f"); - } else { - setNameFn(getResult(), "cst"); - } + setNameFn(getResult(), "f"); } -/// Returns true if a constant operation can be built with the given value and -/// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { - // SymbolRefAttr can only be used with a function type. - if (value.isa()) - return type.isa(); - // Otherwise, this must be a UnitAttr. - return value.isa() && type.isa(); + return value.isa() && type.isa(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp index 495de25662db..52b52763b0dc 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction return failure(); auto loc = multiReductionOp.getLoc(); - Value result = rewriter.create( + Value result = rewriter.create( loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a332029a4cf8..5d7ef65fcad2 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, mlir::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValue(); + Attribute value = constantOp.getValueAttr(); return printConstantOp(emitter, operation, value); } diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir index 836158dd2160..e9359936a196 100644 --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file %s -verify-diagnostics func @unsupported_attribute() { - // expected-error @+1 {{unsupported 'value' attribute: "" : index}} + // expected-error @+1 {{invalid kind of attribute specified}} %0 = constant "" : index return } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index fefe7387f284..55280b2ac8b8 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> - // CHECK: = constant unit - %73 = constant unit - // CHECK: arith.constant true %74 = arith.constant true diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 34dd14176b45..53661511ee32 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern { LogicalResult matchAndRewrite(ILLegalOpG op, PatternRewriter &rewriter) const final { IntegerAttr attr = rewriter.getI32IntegerAttr(0); - Value val = rewriter.create(op->getLoc(), attr); + Value val = rewriter.create(op->getLoc(), attr); rewriter.replaceOpWithNewOp(op, val); return success(); };