forked from OSchip/llvm-project
[mlir:Standard] Remove support for creating a `unit` ConstantOp
This is completely unused upstream, and does not really have well defined semantics on what this is supposed to do/how this fits into the ecosystem. Given that, as part of splitting up the standard dialect it's best to just remove this behavior, instead of try to awkwardly fit it somewhere upstream. Downstream users are encouraged to define their own operations that clearly can define the semantics of this. This also uncovered several lingering uses of ConstantOp that weren't updated to use arith::ConstantOp, and worked during conversions because the constant was removed/converted into something else before verification. See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for more discussion. Differential Revision: https://reviews.llvm.org/D118654
This commit is contained in:
parent
ead1107257
commit
8e123ca65f
|
@ -72,7 +72,7 @@ LLVM_ATTRIBUTE_UNUSED static bool needToMaterialize(mlir::Value str) {
|
|||
/// Unwrap integer constant from mlir::Value.
|
||||
static llvm::Optional<std::int64_t> getIntIfConstant(mlir::Value value) {
|
||||
if (auto *definingOp = value.getDefiningOp())
|
||||
if (auto cst = mlir::dyn_cast<mlir::ConstantOp>(definingOp))
|
||||
if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
|
||||
if (auto intAttr = cst.getValue().dyn_cast<mlir::IntegerAttr>())
|
||||
return intAttr.getInt();
|
||||
return {};
|
||||
|
|
|
@ -376,23 +376,16 @@ def ConstantOp : Std_Op<"constant",
|
|||
operation ::= ssa-id `=` `std.constant` attribute-value `:` type
|
||||
```
|
||||
|
||||
The `constant` operation produces an SSA value equal to some constant
|
||||
specified by an attribute. This is the way that MLIR uses to form simple
|
||||
integer and floating point constants, as well as more exotic things like
|
||||
references to functions and tensor/vector constants.
|
||||
The `constant` operation produces an SSA value from a symbol reference to a
|
||||
`builtin.func` operation
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Complex constant
|
||||
%1 = constant [1.0 : f32, 1.0 : f32] : complex<f32>
|
||||
|
||||
// Reference to function @myfn.
|
||||
%2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
|
||||
|
||||
// Equivalent generic forms
|
||||
%1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex<f32>}
|
||||
: () -> complex<f32>
|
||||
%2 = "std.constant"() {value = @myfn}
|
||||
: () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
|
||||
```
|
||||
|
@ -403,15 +396,9 @@ def ConstantOp : Std_Op<"constant",
|
|||
([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyAttr:$value);
|
||||
let arguments = (ins FlatSymbolRefAttr:$value);
|
||||
let results = (outs AnyType);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Attribute":$value),
|
||||
[{ build($_builder, $_state, value.getType(), value); }]>,
|
||||
OpBuilder<(ins "Attribute":$value, "Type":$type),
|
||||
[{ build($_builder, $_state, type, value); }]>,
|
||||
];
|
||||
let assemblyFormat = "attr-dict $value `:` type(results)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns true if a constant operation can be built with the given value
|
||||
|
|
|
@ -435,31 +435,19 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
|||
LogicalResult
|
||||
matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// If constant refers to a function, convert it to "addressof".
|
||||
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
|
||||
auto type = typeConverter->convertType(op.getResult().getType());
|
||||
if (!type || !LLVM::isCompatibleType(type))
|
||||
return rewriter.notifyMatchFailure(op, "failed to convert result type");
|
||||
auto type = typeConverter->convertType(op.getResult().getType());
|
||||
if (!type || !LLVM::isCompatibleType(type))
|
||||
return rewriter.notifyMatchFailure(op, "failed to convert result type");
|
||||
|
||||
auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
|
||||
symbolRef.getValue());
|
||||
for (const NamedAttribute &attr : op->getAttrs()) {
|
||||
if (attr.getName().strref() == "value")
|
||||
continue;
|
||||
newOp->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
return success();
|
||||
auto newOp =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
|
||||
for (const NamedAttribute &attr : op->getAttrs()) {
|
||||
if (attr.getName().strref() == "value")
|
||||
continue;
|
||||
newOp->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
|
||||
// Calling into other scopes (non-flat reference) is not supported in LLVM.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "referring to a symbol outside of the current module");
|
||||
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
|
||||
*getTypeConverter(), rewriter);
|
||||
rewriter.replaceOp(op, newOp->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -291,7 +291,7 @@ static ParallelComputeFunction createParallelComputeFunction(
|
|||
return llvm::to_vector(
|
||||
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
|
||||
if (IntegerAttr attr = std::get<1>(tuple))
|
||||
return b.create<ConstantOp>(attr);
|
||||
return b.create<arith::ConstantOp>(attr);
|
||||
return std::get<0>(tuple);
|
||||
}));
|
||||
};
|
||||
|
|
|
@ -1576,7 +1576,7 @@ public:
|
|||
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
|
||||
: DenseElementsAttr::get(outputType, intOutputValues);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
|
|||
// Stitch results together into one large vector.
|
||||
Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
|
||||
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
|
||||
Value result = builder.create<ConstantOp>(
|
||||
Value result = builder.create<arith::ConstantOp>(
|
||||
resultExpandedType, builder.getZeroAttr(resultExpandedType));
|
||||
|
||||
for (int64_t i = 0; i < maxLinearIndex; ++i)
|
||||
|
|
|
@ -115,7 +115,10 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
|
|||
Location loc) {
|
||||
if (arith::ConstantOp::isBuildableWith(value, type))
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
return builder.create<ConstantOp>(loc, type, value);
|
||||
if (ConstantOp::isBuildableWith(value, type))
|
||||
return builder.create<ConstantOp>(loc, type,
|
||||
value.cast<FlatSymbolRefAttr>());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -562,97 +565,35 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
|||
// ConstantOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ConstantOp &op) {
|
||||
p << " ";
|
||||
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
|
||||
|
||||
if (op->getAttrs().size() > 1)
|
||||
p << ' ';
|
||||
p << op.getValue();
|
||||
|
||||
// If the value is a symbol reference, print a trailing type.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
||||
static ParseResult parseConstantOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
Attribute valueAttr;
|
||||
if (parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
|
||||
// If the attribute is a symbol reference, then we expect a trailing type.
|
||||
Type type;
|
||||
if (!valueAttr.isa<SymbolRefAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
// Add the attribute type to the list.
|
||||
return parser.addTypeToList(type, result.types);
|
||||
}
|
||||
|
||||
/// The constant op requires an attribute, and furthermore requires that it
|
||||
/// matches the return type.
|
||||
LogicalResult ConstantOp::verify() {
|
||||
auto value = getValue();
|
||||
if (!value)
|
||||
return emitOpError("requires a 'value' attribute");
|
||||
|
||||
StringRef fnName = getValue();
|
||||
Type type = getType();
|
||||
if (!value.getType().isa<NoneType>() && type != value.getType())
|
||||
return emitOpError() << "requires attribute's type (" << value.getType()
|
||||
<< ") to match op's return type (" << type << ")";
|
||||
|
||||
if (type.isa<FunctionType>()) {
|
||||
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
|
||||
if (!fnAttr)
|
||||
return emitOpError("requires 'value' to be a function reference");
|
||||
// Try to find the referenced function.
|
||||
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
|
||||
if (!fn)
|
||||
return emitOpError() << "reference to undefined function '" << fnName
|
||||
<< "'";
|
||||
|
||||
// Try to find the referenced function.
|
||||
auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return emitOpError() << "reference to undefined function '"
|
||||
<< fnAttr.getValue() << "'";
|
||||
// Check that the referenced function has the correct type.
|
||||
if (fn.getType() != type)
|
||||
return emitOpError("reference to function with mismatched type");
|
||||
|
||||
// Check that the referenced function has the correct type.
|
||||
if (fn.getType() != type)
|
||||
return emitOpError("reference to function with mismatched type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
if (type.isa<NoneType>() && value.isa<UnitAttr>())
|
||||
return success();
|
||||
|
||||
return emitOpError("unsupported 'value' attribute: ") << value;
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
return getValue();
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
void ConstantOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
Type type = getType();
|
||||
if (type.isa<FunctionType>()) {
|
||||
setNameFn(getResult(), "f");
|
||||
} else {
|
||||
setNameFn(getResult(), "cst");
|
||||
}
|
||||
setNameFn(getResult(), "f");
|
||||
}
|
||||
|
||||
/// Returns true if a constant operation can be built with the given value and
|
||||
/// result type.
|
||||
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
|
||||
// SymbolRefAttr can only be used with a function type.
|
||||
if (value.isa<SymbolRefAttr>())
|
||||
return type.isa<FunctionType>();
|
||||
// Otherwise, this must be a UnitAttr.
|
||||
return value.isa<UnitAttr>() && type.isa<NoneType>();
|
||||
return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -307,7 +307,7 @@ struct TwoDimMultiReductionToReduction
|
|||
return failure();
|
||||
|
||||
auto loc = multiReductionOp.getLoc();
|
||||
Value result = rewriter.create<ConstantOp>(
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, multiReductionOp.getDestType(),
|
||||
rewriter.getZeroAttr(multiReductionOp.getDestType()));
|
||||
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
|
||||
|
|
|
@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
|
|||
static LogicalResult printOperation(CppEmitter &emitter,
|
||||
mlir::ConstantOp constantOp) {
|
||||
Operation *operation = constantOp.getOperation();
|
||||
Attribute value = constantOp.getValue();
|
||||
Attribute value = constantOp.getValueAttr();
|
||||
|
||||
return printConstantOp(emitter, operation, value);
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
|
||||
|
||||
func @unsupported_attribute() {
|
||||
// expected-error @+1 {{unsupported 'value' attribute: "" : index}}
|
||||
// expected-error @+1 {{invalid kind of attribute specified}}
|
||||
%0 = constant "" : index
|
||||
return
|
||||
}
|
||||
|
|
|
@ -99,9 +99,6 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
|
|||
// CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32>
|
||||
%70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32>
|
||||
|
||||
// CHECK: = constant unit
|
||||
%73 = constant unit
|
||||
|
||||
// CHECK: arith.constant true
|
||||
%74 = arith.constant true
|
||||
|
||||
|
|
|
@ -578,7 +578,7 @@ struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> {
|
|||
LogicalResult matchAndRewrite(ILLegalOpG op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
IntegerAttr attr = rewriter.getI32IntegerAttr(0);
|
||||
Value val = rewriter.create<ConstantOp>(op->getLoc(), attr);
|
||||
Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr);
|
||||
rewriter.replaceOpWithNewOp<LegalOpC>(op, val);
|
||||
return success();
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue