[mlir:PDLInterp] Refactor the implementation of result type inferrence

The current implementation uses a discrete "pdl_interp.inferred_types"
operation, which acts as a "fake" handle to a type range. This op is
used as a signal to pdl_interp.create_operation that types should be
inferred. This is terribly awkward and clunky though:

* This op doesn't have a byte code representation, and its conversion
  to bytecode kind of assumes that it is only used in a certain way. The
  current lowering is also broken and seemingly untested.

* Given that this is a different operation, it gives off the assumption
  that it can be used multiple times, or that after the first use
  the value contains the inferred types. This isn't the case though,
  the resultant type range can never actually be used as a type range.

This commit refactors the representation by removing the discrete
InferredTypesOp, and instead adds a UnitAttr to
pdl_interp.CreateOperation that signals when the created operations
should infer their types. This leads to a much much cleaner abstraction,
a more optimal bytecode lowering, and also allows for better error
handling and diagnostics when a created operation doesn't actually
support type inferrence.

Differential Revision: https://reviews.llvm.org/D124587
This commit is contained in:
River Riddle 2022-04-26 13:38:21 -07:00
parent 5387a38c38
commit 3c75228991
8 changed files with 248 additions and 107 deletions

View File

@ -409,14 +409,18 @@ def PDLInterp_CreateOperationOp
let description = [{
`pdl_interp.create_operation` operations create an `Operation` instance with
the specified attributes, operands, and result types. See `pdl.operation`
for a more detailed description on the interpretation of the arguments to
this operation.
for a more detailed description on the general interpretation of the arguments
to this operation.
Example:
```mlir
// Create an instance of a `foo.op` operation.
%op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> (%type : !pdl.type)
// Create an instance of a `foo.op` operation that has inferred result types
// (using the InferTypeOpInterface).
%op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> <inferred>
```
}];
@ -424,22 +428,26 @@ def PDLInterp_CreateOperationOp
Variadic<PDL_InstOrRangeOf<PDL_Value>>:$inputOperands,
Variadic<PDL_Attribute>:$inputAttributes,
StrArrayAttr:$inputAttributeNames,
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes);
Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes,
UnitAttr:$inferredResultTypes);
let results = (outs PDL_Operation:$resultOp);
let builders = [
OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
"ValueRange":$operands, "ValueRange":$attributes,
"ArrayAttr":$attributeNames), [{
"bool":$inferredResultTypes, "ValueRange":$operands,
"ValueRange":$attributes, "ArrayAttr":$attributeNames), [{
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
operands, attributes, attributeNames, types);
operands, attributes, attributeNames, types, inferredResultTypes);
}]>
];
let assemblyFormat = [{
$name (`(` $inputOperands^ `:` type($inputOperands) `)`)?
$name (`(` $inputOperands^ `:` type($inputOperands) `)`)? ``
custom<CreateOperationOpAttributes>($inputAttributes, $inputAttributeNames)
(`->` `(` $inputResultTypes^ `:` type($inputResultTypes) `)`)? attr-dict
custom<CreateOperationOpResults>($inputResultTypes, type($inputResultTypes),
$inferredResultTypes)
attr-dict
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -961,33 +969,6 @@ def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect,
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::InferredTypesOp
//===----------------------------------------------------------------------===//
def PDLInterp_InferredTypesOp : PDLInterp_Op<"inferred_types"> {
let summary = "Generate a handle to a range of Types that are \"inferred\"";
let description = [{
`pdl_interp.inferred_types` operations generate handles to ranges of types
that should be inferred. This signals to other operations, such as
`pdl_interp.create_operation`, that these types should be inferred.
Example:
```mlir
%types = pdl_interp.inferred_types
```
}];
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "attr-dict";
let builders = [
OpBuilder<(ins), [{
build($_builder, $_state,
pdl::RangeType::get($_builder.getType<pdl::TypeType>()));
}]>
];
}
//===----------------------------------------------------------------------===//
// pdl_interp::IsNotNullOp
//===----------------------------------------------------------------------===//

View File

@ -100,11 +100,12 @@ private:
function_ref<Value(Value)> mapRewriteValue);
/// Generate the values used for resolving the result types of an operation
/// created within a dag rewriter region.
/// created within a dag rewriter region. If the result types of the operation
/// should be inferred, `hasInferredResultTypes` is set to true.
void generateOperationResultTypeRewriter(
pdl::OperationOp op, SmallVectorImpl<Value> &types,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
bool &hasInferredResultTypes);
/// A builder to use when generating interpreter operations.
OpBuilder builder;
@ -707,15 +708,16 @@ void PatternLowering::generateRewriter(
for (Value attr : operationOp.attributes())
attributes.push_back(mapRewriteValue(attr));
bool hasInferredResultTypes = false;
SmallVector<Value, 2> types;
generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
mapRewriteValue);
generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
rewriteValues, hasInferredResultTypes);
// Create the new operation.
Location loc = operationOp.getLoc();
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.name(), types, operands, attributes,
operationOp.attributeNames());
loc, *operationOp.name(), types, hasInferredResultTypes, operands,
attributes, operationOp.attributeNames());
rewriteValues[operationOp.op()] = createdOp;
// Generate accesses for any results that have their types constrained.
@ -825,9 +827,9 @@ void PatternLowering::generateRewriter(
}
void PatternLowering::generateOperationResultTypeRewriter(
pdl::OperationOp op, SmallVectorImpl<Value> &types,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
bool &hasInferredResultTypes) {
// Look for an operation that was replaced by `op`. The result types will be
// inferred from the results that were replaced.
Block *rewriterBlock = op->getBlock();
@ -851,14 +853,11 @@ void PatternLowering::generateOperationResultTypeRewriter(
return;
}
// Check if the operation has type inference support.
if (op.hasTypeInference()) {
types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
return;
}
// Otherwise, handle inference for each of the result types individually.
// Try to handle resolution for each of the result types individually. This is
// preferred over type inferrence because it will allow for us to use existing
// types directly, as opposed to trying to rebuild the type list.
OperandRange resultTypeValues = op.types();
auto tryResolveResultTypes = [&] {
types.reserve(resultTypeValues.size());
for (const auto &it : llvm::enumerate(resultTypeValues)) {
Value resultType = it.value();
@ -875,12 +874,33 @@ void PatternLowering::generateOperationResultTypeRewriter(
continue;
}
// Otherwise, we couldn't infer the result types. Bail out here to see if
// we can infer the types for this operation from another way.
types.clear();
return failure();
}
return success();
};
if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
return;
// Otherwise, check if the operation has type inference support itself.
if (op.hasTypeInference()) {
hasInferredResultTypes = true;
return;
}
// If the types could not be inferred from any context and there weren't any
// explicit result types, assume the user actually meant for the operation to
// have no results.
if (resultTypeValues.empty())
return;
// The verifier asserts that the result types of each pdl.operation can be
// inferred. If we reach here, there is a bug either in the logic above or
// in the verifier for pdl.operation.
op->emitOpError() << "unable to infer result type for operation";
llvm_unreachable("unable to infer result type for operation");
}
}
//===----------------------------------------------------------------------===//

View File

@ -47,6 +47,23 @@ static LogicalResult verifySwitchOp(OpT op) {
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
LogicalResult CreateOperationOp::verify() {
if (!getInferredResultTypes())
return success();
if (!getInputResultTypes().empty()) {
return emitOpError("with inferred results cannot also have "
"explicit result types");
}
OperationName opName(getName(), getContext());
if (!opName.hasInterface<InferTypeOpInterface>()) {
return emitOpError()
<< "has inferred results, but the created operation '" << opName
<< "' does not support result type inference (or is not "
"registered)";
}
return success();
}
static ParseResult parseCreateOperationOpAttributes(
OpAsmParser &p,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
@ -82,6 +99,41 @@ static void printCreateOperationOpAttributes(OpAsmPrinter &p,
p << '}';
}
static ParseResult parseCreateOperationOpResults(
OpAsmParser &p,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
if (failed(p.parseOptionalArrow()))
return success();
// Handle the case of inferred results.
if (succeeded(p.parseOptionalLess())) {
if (p.parseKeyword("inferred") || p.parseGreater())
return failure();
inferredResultTypes = p.getBuilder().getUnitAttr();
return success();
}
// Otherwise, parse the explicit results.
return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
p.parseColonTypeList(resultTypes) || p.parseRParen());
}
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
OperandRange resultOperands,
TypeRange resultTypes,
UnitAttr inferredResultTypes) {
// Handle the case of inferred results.
if (inferredResultTypes) {
p << " -> <inferred>";
return;
}
// Otherwise, handle the explicit results.
if (!resultTypes.empty())
p << " -> (" << resultOperands << " : " << resultTypes << ")";
}
//===----------------------------------------------------------------------===//
// pdl_interp::ForEachOp
//===----------------------------------------------------------------------===//

View File

@ -162,6 +162,10 @@ enum OpCode : ByteCodeField {
};
} // namespace
/// A marker used to indicate if an operation should infer types.
static constexpr ByteCodeField kInferTypesMarker =
std::numeric_limits<ByteCodeField>::max();
//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//
@ -273,7 +277,6 @@ private:
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
@ -723,8 +726,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
LLVM_DEBUG({
// The following list must contain all the operations that do not
// produce any bytecode.
if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
pdl_interp::InferredTypesOp>(op))
if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
writer.appendInline(op->getLoc());
});
TypeSwitch<Operation *>(op)
@ -742,11 +744,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@ -847,6 +849,12 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
writer.append(static_cast<ByteCodeField>(attributes.size()));
for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
writer.append(std::get<0>(it), std::get<1>(it));
// Add the result types. If the operation has inferred results, we use a
// marker "size" value. Otherwise, we add the list of explicit result types.
if (op.getInferredResultTypes())
writer.append(kInferTypesMarker);
else
writer.appendPDLValueList(op.getInputResultTypes());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
@ -955,12 +963,6 @@ void Generator::generate(pdl_interp::GetValueTypeOp op,
writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
}
}
void Generator::generate(pdl_interp::InferredTypesOp op,
ByteCodeWriter &writer) {
// InferType maps to a null type as a marker for inferring result types.
getMemIndex(op.getResult()) = getMemIndex(Type());
}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
@ -1526,30 +1528,31 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.addAttribute(name, attr);
}
for (unsigned i = 0, e = read(); i != e; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
state.types.push_back(read<Type>());
continue;
}
// If we find a null range, this signals that the types are infered.
if (TypeRange *resultTypes = read<TypeRange *>()) {
state.types.append(resultTypes->begin(), resultTypes->end());
continue;
}
// Handle the case where the operation has inferred types.
// Read in the result types. If the "size" is the sentinel value, this
// indicates that the result types should be inferred.
unsigned numResults = read();
if (numResults == kInferTypesMarker) {
InferTypeOpInterface::Concept *inferInterface =
state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
assert(inferInterface &&
"expected operation to provide InferTypeOpInterface");
// TODO: Handle failure.
state.types.clear();
if (failed(inferInterface->inferReturnTypes(
state.getContext(), state.location, state.operands,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
return;
break;
} else {
// Otherwise, this is a fixed number of results.
for (unsigned i = 0; i != numResults; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
state.types.push_back(read<Type>());
} else {
TypeRange *resultTypes = read<TypeRange *>();
state.types.append(resultTypes->begin(), resultTypes->end());
}
}
}
Operation *resultOp = rewriter.create(state);

View File

@ -127,6 +127,29 @@ module @operation_infer_types_from_otherop_results {
// -----
// CHECK-LABEL: module @operation_infer_types_from_interface
module @operation_infer_types_from_interface {
// Unused operation that ensures the arithmetic dialect is loaded for use in the pattern.
arith.constant true
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter
// CHECK: %[[CST:.*]] = pdl_interp.create_operation "arith.constant" -> <inferred>
// CHECK: %[[CST_RES:.*]] = pdl_interp.get_results of %[[CST]] : !pdl.range<value>
// CHECK: %[[CST_TYPE:.*]] = pdl_interp.get_value_type of %[[CST_RES]] : !pdl.range<type>
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[CST_TYPE]] : !pdl.range<type>)
pdl.pattern : benefit(1) {
%root = operation "foo.op"
rewrite %root {
%types = types
%newOp = operation "arith.constant" -> (%types : !pdl.range<type>)
%newOp2 = operation "foo.op" -> (%types : !pdl.range<type>)
}
}
}
// -----
// CHECK-LABEL: module @replace_with_op
module @replace_with_op {
// CHECK: module @rewriters

View File

@ -0,0 +1,26 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
//===----------------------------------------------------------------------===//
// pdl::CreateOperationOp
//===----------------------------------------------------------------------===//
pdl_interp.func @rewriter() {
// expected-error@+1 {{op has inferred results, but the created operation 'foo.op' does not support result type inference}}
%op = pdl_interp.create_operation "foo.op" -> <inferred>
pdl_interp.finalize
}
// -----
pdl_interp.func @rewriter() {
%type = pdl_interp.create_type i32
// expected-error@+1 {{op with inferred results cannot also have explicit result types}}
%op = "pdl_interp.create_operation"(%type) {
inferredResultTypes,
inputAttributeNames = [],
name = "foo.op",
operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>
} : (!pdl.type) -> (!pdl.operation)
pdl_interp.finalize
}

View File

@ -6,6 +6,10 @@
// -----
// Unused operation to force loading the `arithmetic` dialect for the
// test of type inferrence.
arith.constant true
func.func @operations(%attribute: !pdl.attribute,
%input: !pdl.value,
%type: !pdl.type) {
@ -21,6 +25,9 @@ func.func @operations(%attribute: !pdl.attribute,
// operands, and results
%op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type)
// inferred results
%op4 = pdl_interp.create_operation "arith.constant" -> <inferred>
pdl_interp.finalize
}

View File

@ -531,6 +531,41 @@ module @ir attributes { test.check_types_1 } {
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
// Unused operation to force loading the `arithmetic` dialect for the
// test of type inferrence.
arith.constant 10
// Test support for inferring the types of an operation.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
^pat:
pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation) {
%attr = pdl_interp.create_attribute true
%cst = pdl_interp.create_operation "arith.constant" {"value" = %attr} -> <inferred>
%cstResults = pdl_interp.get_results of %cst : !pdl.range<value>
%op = pdl_interp.create_operation "test.success"(%cstResults : !pdl.range<value>)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.create_op_infer_results
// CHECK: %[[CST:.*]] = arith.constant true
// CHECK: "test.success"(%[[CST]])
module @ir attributes { test.create_op_infer_results } {
%results:2 = "test.op"() : () -> (i64, i64)
}
// -----
//===----------------------------------------------------------------------===//
@ -1181,12 +1216,6 @@ module @ir attributes { test.get_results_2 } {
// Fully tested within the tests for other operations.
//===----------------------------------------------------------------------===//
// pdl_interp::InferredTypesOp
//===----------------------------------------------------------------------===//
// Fully tested within the tests for other operations.
//===----------------------------------------------------------------------===//
// pdl_interp::IsNotNullOp
//===----------------------------------------------------------------------===//