[mlir:Standard] Remove support for creating a `unit` ConstantOp

This is completely unused upstream, and does not really have well defined semantics
on what this is supposed to do/how this fits into the ecosystem. Given that, as part of
splitting up the standard dialect it's best to just remove this behavior, instead of try
to awkwardly fit it somewhere upstream. Downstream users are encouraged to
define their own operations that clearly can define the semantics of this.

This also uncovered several lingering uses of ConstantOp that weren't
updated to use arith::ConstantOp, and worked during conversions because
the constant was removed/converted into something else before
verification.

See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for more discussion.

Differential Revision: https://reviews.llvm.org/D118654
This commit is contained in:
River Riddle 2022-01-31 13:53:22 -08:00
parent ead1107257
commit 8e123ca65f
12 changed files with 40 additions and 127 deletions

View File

@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) {
/// Unwrap integer constant from mlir::Value. /// Unwrap integer constant from mlir::Value.
static llvm::Optional<std::int64_t> getIntIfConstant(mlir::Value value) { static llvm::Optional<std::int64_t> getIntIfConstant(mlir::Value value) {
if (auto *definingOp = value.getDefiningOp()) if (auto *definingOp = value.getDefiningOp())
if (auto cst = mlir::dyn_cast<mlir::ConstantOp>(definingOp)) if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>()) if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>())
return intAttr.getInt(); return intAttr.getInt();
return {}; return {};

View File

@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant",
operation ::= ssa-id `=` `std.constant` attribute-value `:` type operation ::= ssa-id `=` `std.constant` attribute-value `:` type
``` ```
The `constant` operation produces an SSA value equal to some constant The `constant` operation produces an SSA value from a symbol reference to a
specified by an attribute. This is the way that MLIR uses to form simple `builtin.func` operation
integer and floating point constants, as well as more exotic things like
references to functions and tensor/vector constants.
Example: Example:
```mlir ```mlir
// Complex constant
%1 = constant [1.0 : f32, 1.0 : f32] : complex<f32>
// Reference to function @myfn. // Reference to function @myfn.
%2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
// Equivalent generic forms // Equivalent generic forms
%1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex<f32>}
: () -> complex<f32>
%2 = "std.constant"() {value = @myfn} %2 = "std.constant"() {value = @myfn}
: () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
``` ```
@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant",
([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
}]; }];
let arguments = (ins AnyAttr:$value); let arguments = (ins FlatSymbolRefAttr:$value);
let results = (outs AnyType); let results = (outs AnyType);
let assemblyFormat = "attr-dict $value `:` type(results)";
let builders = [
OpBuilder<(ins "Attribute":$value),
[{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Returns true if a constant operation can be built with the given value /// Returns true if a constant operation can be built with the given value

View File

@ -435,14 +435,12 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
LogicalResult LogicalResult
matchAndRewrite(ConstantOp op, OpAdaptor adaptor, matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
auto type = typeConverter->convertType(op.getResult().getType()); auto type = typeConverter->convertType(op.getResult().getType());
if (!type || !LLVM::isCompatibleType(type)) if (!type || !LLVM::isCompatibleType(type))
return rewriter.notifyMatchFailure(op, "failed to convert result type"); return rewriter.notifyMatchFailure(op, "failed to convert result type");
auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, auto newOp =
symbolRef.getValue()); rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
for (const NamedAttribute &attr : op->getAttrs()) { for (const NamedAttribute &attr : op->getAttrs()) {
if (attr.getName().strref() == "value") if (attr.getName().strref() == "value")
continue; continue;
@ -451,16 +449,6 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
rewriter.replaceOp(op, newOp->getResults()); rewriter.replaceOp(op, newOp->getResults());
return success(); return success();
} }
// Calling into other scopes (non-flat reference) is not supported in LLVM.
if (op.getValue().isa<SymbolRefAttr>())
return rewriter.notifyMatchFailure(
op, "referring to a symbol outside of the current module");
return LLVM::detail::oneToOneRewrite(
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
*getTypeConverter(), rewriter);
}
}; };
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and // A CallOp automatically promotes MemRefType to a sequence of alloca/store and

View File

@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction(
return llvm::to_vector( return llvm::to_vector(
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
if (IntegerAttr attr = std::get<1>(tuple)) if (IntegerAttr attr = std::get<1>(tuple))
return b.create<ConstantOp>(attr); return b.create<arith::ConstantOp>(attr);
return std::get<0>(tuple); return std::get<0>(tuple);
})); }));
}; };

View File

@ -1576,7 +1576,7 @@ public:
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
: DenseElementsAttr::get(outputType, intOutputValues); : DenseElementsAttr::get(outputType, intOutputValues);
rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr); rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
return success(); return success();
} }

View File

@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Stitch results together into one large vector. // Stitch results together into one large vector.
Type resultEltType = results[0].getType().cast<VectorType>().getElementType(); Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType); Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
Value result = builder.create<ConstantOp>( Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType)); resultExpandedType, builder.getZeroAttr(resultExpandedType));
for (int64_t i = 0; i < maxLinearIndex; ++i) for (int64_t i = 0; i < maxLinearIndex; ++i)

View File

@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
Location loc) { Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type)) if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value); return builder.create<arith::ConstantOp>(loc, type, value);
return builder.create<ConstantOp>(loc, type, value); if (ConstantOp::isBuildableWith(value, type))
return builder.create<ConstantOp>(loc, type,
value.cast<FlatSymbolRefAttr>());
return nullptr;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
// ConstantOp // ConstantOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void print(OpAsmPrinter &p, ConstantOp &op) {
p << " ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
if (op->getAttrs().size() > 1)
p << ' ';
p << op.getValue();
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
p << " : " << op.getType();
}
static ParseResult parseConstantOp(OpAsmParser &parser,
OperationState &result) {
Attribute valueAttr;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
// Add the attribute type to the list.
return parser.addTypeToList(type, result.types);
}
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
LogicalResult ConstantOp::verify() { LogicalResult ConstantOp::verify() {
auto value = getValue(); StringRef fnName = getValue();
if (!value)
return emitOpError("requires a 'value' attribute");
Type type = getType(); Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
return emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function. // Try to find the referenced function.
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>( auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
fnAttr.getValue());
if (!fn) if (!fn)
return emitOpError() << "reference to undefined function '" return emitOpError() << "reference to undefined function '" << fnName
<< fnAttr.getValue() << "'"; << "'";
// Check that the referenced function has the correct type. // Check that the referenced function has the correct type.
if (fn.getType() != type) if (fn.getType() != type)
return emitOpError("reference to function with mismatched type"); return emitOpError("reference to function with mismatched type");
return success(); return success();
}
if (type.isa<NoneType>() && value.isa<UnitAttr>())
return success();
return emitOpError("unsupported 'value' attribute: ") << value;
} }
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands"); assert(operands.empty() && "constant has no operands");
return getValue(); return getValueAttr();
} }
void ConstantOp::getAsmResultNames( void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { function_ref<void(Value, StringRef)> setNameFn) {
Type type = getType();
if (type.isa<FunctionType>()) {
setNameFn(getResult(), "f"); setNameFn(getResult(), "f");
} else {
setNameFn(getResult(), "cst");
}
} }
/// Returns true if a constant operation can be built with the given value and
/// result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) { bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// SymbolRefAttr can only be used with a function type. return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// Otherwise, this must be a UnitAttr.
return value.isa<UnitAttr>() && type.isa<NoneType>();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction
return failure(); return failure();
auto loc = multiReductionOp.getLoc(); auto loc = multiReductionOp.getLoc();
Value result = rewriter.create<ConstantOp>( Value result = rewriter.create<arith::ConstantOp>(
loc, multiReductionOp.getDestType(), loc, multiReductionOp.getDestType(),
rewriter.getZeroAttr(multiReductionOp.getDestType())); rewriter.getZeroAttr(multiReductionOp.getDestType()));
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

View File

@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter,
mlir::ConstantOp constantOp) { mlir::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation(); Operation *operation = constantOp.getOperation();
Attribute value = constantOp.getValue(); Attribute value = constantOp.getValueAttr();
return printConstantOp(emitter, operation, value); return printConstantOp(emitter, operation, value);
} }

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics // RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @unsupported_attribute() { func @unsupported_attribute() {
// expected-error @+1 {{unsupported 'value' attribute: "" : index}} // expected-error @+1 {{invalid kind of attribute specified}}
%0 = constant "" : index %0 = constant "" : index
return return
} }

View File

@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
%70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
// CHECK: = constant unit
%73 = constant unit
// CHECK: arith.constant true // CHECK: arith.constant true
%74 = arith.constant true %74 = arith.constant true

View File

@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
LogicalResult matchAndRewrite(ILLegalOpG op, LogicalResult matchAndRewrite(ILLegalOpG op,
PatternRewriter &rewriter) const final { PatternRewriter &rewriter) const final {
IntegerAttr attr = rewriter.getI32IntegerAttr(0); IntegerAttr attr = rewriter.getI32IntegerAttr(0);
Value val = rewriter.create<ConstantOp>(op->getLoc(), attr); Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
rewriter.replaceOpWithNewOp<LegalOpC>(op, val); rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
return success(); return success();
}; };