Add support for 1->N type mappings in the dialect conversion infrastructure. To support these mappings a hook must be overridden on the type converter: 'materializeConversion' :to generate a cast operation from the new types to the old type. This operation is automatically erased if all uses are removed, otherwise it remains in the IR for the user to handle.

PiperOrigin-RevId: 254411383
This commit is contained in:
River Riddle 2019-06-21 09:29:46 -07:00 committed by jpienaar
parent d080efefe0
commit 704a7fb13e
5 changed files with 143 additions and 24 deletions

View File

@ -216,6 +216,19 @@ public:
virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type, virtual LogicalResult convertSignatureArg(unsigned inputNo, Type type,
NamedAttributeList attrs, NamedAttributeList attrs,
SignatureConversion &result); SignatureConversion &result);
/// This hook allows for materializing a conversion from a set of types into
/// one result type by generating a cast operation of some kind. The generated
/// operation should produce one result, of 'resultType', with the provided
/// 'inputs' as operands. This hook must be overridden when a type conversion
/// results in more than one type.
virtual Operation *materializeConversion(PatternRewriter &rewriter,
Type resultType,
ArrayRef<Value *> inputs,
Location loc) {
llvm_unreachable("expected 'materializeConversion' to be overridden when "
"generating 1->N type conversions");
}
}; };
/// This class describes a specific conversion target. /// This class describes a specific conversion target.

View File

@ -51,6 +51,12 @@ struct ArgConverter {
if (it == argMapping.end()) if (it == argMapping.end())
continue; continue;
for (auto *op : it->second) { for (auto *op : it->second) {
// If the operation exists within the parent block, like with 1->N cast
// operations, we don't need to drop them. They will be automatically
// cleaned up with the region is destroyed.
if (op->getBlock())
continue;
op->dropAllDefinedValueUses(); op->dropAllDefinedValueUses();
op->destroy(); op->destroy();
} }
@ -77,7 +83,13 @@ struct ArgConverter {
auto *op = argOps[i]; auto *op = argOps[i];
auto *arg = block->addArgument(op->getResult(0)->getType()); auto *arg = block->addArgument(op->getResult(0)->getType());
op->getResult(0)->replaceAllUsesWith(arg); op->getResult(0)->replaceAllUsesWith(arg);
op->destroy();
// If this was a 1->N value mapping it exists within the parent block so
// erase it instead of destroying.
if (op->getBlock())
op->erase();
else
op->destroy();
} }
} }
argMapping.clear(); argMapping.clear();
@ -97,8 +109,14 @@ struct ArgConverter {
auto *op = argOps[i]; auto *op = argOps[i];
// Handle the case of a 1->N value mapping. // Handle the case of a 1->N value mapping.
if (op->getNumOperands() > 1) if (op->getNumOperands() > 1) {
llvm_unreachable("1->N argument mappings are currently not handled"); // If all of the uses were removed, we can drop this op. Otherwise,
// keep the operation alive and let the user handle any remaining
// usages.
if (op->use_empty())
op->erase();
continue;
}
// Handle the case where this argument had a direct mapping. // Handle the case where this argument had a direct mapping.
if (op->getNumOperands() == 1) { if (op->getNumOperands() == 1) {
@ -132,7 +150,8 @@ struct ArgConverter {
} }
/// Converts the signature of the given entry block. /// Converts the signature of the given entry block.
void convertSignature(Block *block, void convertSignature(Block *block, PatternRewriter &rewriter,
TypeConverter &converter,
TypeConverter::SignatureConversion &signatureConversion, TypeConverter::SignatureConversion &signatureConversion,
BlockAndValueMapping &mapping) { BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments(); unsigned origArgCount = block->getNumArguments();
@ -146,13 +165,15 @@ struct ArgConverter {
// Remap each of the original arguments as determined by the signature // Remap each of the original arguments as determined by the signature
// conversion. // conversion.
auto &newArgMapping = argMapping[block]; auto &newArgMapping = argMapping[block];
rewriter.setInsertionPointToStart(block);
for (unsigned i = 0; i != origArgCount; ++i) { for (unsigned i = 0; i != origArgCount; ++i) {
ArrayRef<Value *> remappedValues; ArrayRef<Value *> remappedValues;
if (auto inputMap = signatureConversion.getInputMapping(i)) if (auto inputMap = signatureConversion.getInputMapping(i))
remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size); remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
BlockArgument *arg = block->getArgument(i); BlockArgument *arg = block->getArgument(i);
newArgMapping.push_back(convertArgument(arg, remappedValues, mapping)); newArgMapping.push_back(
convertArgument(arg, remappedValues, rewriter, converter, mapping));
} }
// Erase all of the original arguments. // Erase all of the original arguments.
@ -161,7 +182,8 @@ struct ArgConverter {
} }
/// Converts the arguments of the given block. /// Converts the arguments of the given block.
LogicalResult convertArguments(Block *block, TypeConverter &converter, LogicalResult convertArguments(Block *block, PatternRewriter &rewriter,
TypeConverter &converter,
BlockAndValueMapping &mapping) { BlockAndValueMapping &mapping) {
unsigned origArgCount = block->getNumArguments(); unsigned origArgCount = block->getNumArguments();
if (origArgCount == 0) if (origArgCount == 0)
@ -178,10 +200,11 @@ struct ArgConverter {
// Remap all of the original argument values. // Remap all of the original argument values.
auto &newArgMapping = argMapping[block]; auto &newArgMapping = argMapping[block];
rewriter.setInsertionPointToStart(block);
for (unsigned i = 0; i != origArgCount; ++i) { for (unsigned i = 0; i != origArgCount; ++i) {
SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i])); SmallVector<Value *, 1> newArgs(block->addArguments(newArgTypes[i]));
newArgMapping.push_back( newArgMapping.push_back(convertArgument(block->getArgument(i), newArgs,
convertArgument(block->getArgument(i), newArgs, mapping)); rewriter, converter, mapping));
} }
// Erase all of the original arguments. // Erase all of the original arguments.
@ -195,6 +218,8 @@ struct ArgConverter {
/// to perform the conversion. /// to perform the conversion.
Operation *convertArgument(BlockArgument *origArg, Operation *convertArgument(BlockArgument *origArg,
ArrayRef<Value *> newValues, ArrayRef<Value *> newValues,
PatternRewriter &rewriter,
TypeConverter &converter,
BlockAndValueMapping &mapping) { BlockAndValueMapping &mapping) {
// Handle the cases of 1->0 or 1->1 mappings. // Handle the cases of 1->0 or 1->1 mappings.
if (newValues.size() < 2) { if (newValues.size() < 2) {
@ -209,7 +234,15 @@ struct ArgConverter {
mapping.map(cast->getResult(0), newValues[0]); mapping.map(cast->getResult(0), newValues[0]);
return cast; return cast;
} }
llvm_unreachable("1->N argument mappings are currently not handled");
// Otherwise, this is a 1->N mapping. Call into the provided type converter
// to pack the new values.
auto *cast = converter.materializeConversion(rewriter, origArg->getType(),
newValues, loc);
assert(cast->getNumResults() == 1 &&
cast->getNumOperands() == newValues.size());
origArg->replaceAllUsesWith(cast->getResult(0));
return cast;
} }
/// A utility function used to create a conversion cast operation with the /// A utility function used to create a conversion cast operation with the
@ -874,10 +907,11 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
// types. // types.
if (typeConverter) { if (typeConverter) {
for (Block &block : for (Block &block :
llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) llvm::drop_begin(region.getBlocks(), convertEntryTypes ? 0 : 1)) {
if (failed(rewriter.argConverter.convertArguments(&block, *typeConverter, if (failed(rewriter.argConverter.convertArguments(
rewriter.mapping))) &block, rewriter, *typeConverter, rewriter.mapping)))
return failure(); return failure();
}
} }
// Store the number of blocks before conversion (new blocks may be added due // Store the number of blocks before conversion (new blocks may be added due
@ -909,8 +943,9 @@ LogicalResult FunctionConverter::convertFunction(
// Update the signature of the entry block. // Update the signature of the entry block.
if (signatureConversion) { if (signatureConversion) {
rewriter.argConverter.convertSignature( rewriter.argConverter.convertSignature(&f->getBody().front(), rewriter,
&f->getBody().front(), *signatureConversion, rewriter.mapping); *typeConverter, *signatureConversion,
rewriter.mapping);
} }
// Rewrite the function body. // Rewrite the function body.

View File

@ -227,4 +227,13 @@ def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;
def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>; def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>; def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;
//===----------------------------------------------------------------------===//
// Test Type Legalization
//===----------------------------------------------------------------------===//
def TestReturnOp : TEST_Op<"return", [Terminator]>,
Arguments<(ins Variadic<AnyType>:$inputs)>;
def TestCastOp : TEST_Op<"cast">,
Arguments<(ins Variadic<AnyType>:$inputs)>, Results<(outs AnyType:$res)>;
#endif // TEST_OPS #endif // TEST_OPS

View File

@ -49,6 +49,7 @@ static mlir::PassRegistration<TestPatternDriver>
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Legalization Driver. // Legalization Driver.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace {
/// This pattern is a simple pattern that inlines the first region of a given /// This pattern is a simple pattern that inlines the first region of a given
/// operation into the parent region. /// operation into the parent region.
@ -77,6 +78,29 @@ struct TestDropOp : public ConversionPattern {
return matchSuccess(); return matchSuccess();
} }
}; };
/// This pattern handles the case of a split return value.
struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
return matchFailure();
// Check if the first operation is a cast operation, if it is we use the
// results directly.
auto *defOp = operands[0]->getDefiningOp();
if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
SmallVector<Value *, 2> returnOperands(packerOp.getOperands());
rewriter.replaceOpWithNewOp<TestReturnOp>(op, returnOperands);
return matchSuccess();
}
// Otherwise, fail to match.
return matchFailure();
}
};
} // namespace } // namespace
namespace { namespace {
@ -94,10 +118,35 @@ struct TestTypeConverter : public TypeConverter {
return success(); return success();
} }
// Split F32 into F16,F16.
if (t.isF32()) {
results.assign(2, FloatType::getF16(t.getContext()));
return success();
}
// Otherwise, convert the type directly. // Otherwise, convert the type directly.
results.push_back(t); results.push_back(t);
return success(); return success();
} }
/// Override the hook to materialize a conversion. This is necessary because
/// we generate 1->N type mappings.
Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
ArrayRef<Value *> inputs, Location loc) {
return rewriter.create<TestCastOp>(loc, resultType, inputs);
}
};
struct TestConversionTarget : public ConversionTarget {
TestConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
addLegalOp<LegalOpA>();
addDynamicallyLegalOp<TestReturnOp>();
}
bool isDynamicallyLegal(Operation *op) const final {
// Don't allow F32 operands.
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isF32(); });
}
}; };
struct TestLegalizePatternDriver struct TestLegalizePatternDriver
@ -105,12 +154,11 @@ struct TestLegalizePatternDriver
void runOnModule() override { void runOnModule() override {
mlir::OwningRewritePatternList patterns; mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns); populateWithGenerated(&getContext(), &patterns);
RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp>::build( RewriteListBuilder<TestRegionRewriteBlockMovement, TestDropOp,
patterns, &getContext()); TestSplitReturnType>::build(patterns, &getContext());
TestTypeConverter converter; TestTypeConverter converter;
ConversionTarget target(getContext()); TestConversionTarget target(getContext());
target.addLegalOp<LegalOpA>();
if (failed(applyConversionPatterns(getModule(), target, converter, if (failed(applyConversionPatterns(getModule(), target, converter,
std::move(patterns)))) std::move(patterns))))
signalPassFailure(); signalPassFailure();

View File

@ -23,6 +23,19 @@ func @remap_input_1_to_1(%arg0: i64) -> i64 {
return %arg0 : i64 return %arg0 : i64
} }
// CHECK-LABEL: func @remap_input_1_to_N(%arg0: f16, %arg1: f16) -> (f16, f16)
func @remap_input_1_to_N(%arg0: f32) -> f32 {
// CHECK-NEXT: "test.return"(%arg0, %arg1) : (f16, f16) -> ()
"test.return"(%arg0) : (f32) -> ()
}
// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
func @remap_input_1_to_N_remaining_use(%arg0: f32) {
// CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
// CHECK-NEXT: "work"([[CAST]]) : (f32) -> ()
"work"(%arg0) : (f32) -> ()
}
// CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64) // CHECK-LABEL: func @remap_multi(%arg0: f64, %arg1: f64) -> (f64, f64)
func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) { func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
// CHECK-NEXT: return %arg0, %arg1 : f64, f64 // CHECK-NEXT: return %arg0, %arg1 : f64, f64
@ -44,11 +57,12 @@ func @remap_nested() {
// CHECK-LABEL: func @remap_moved_region_args // CHECK-LABEL: func @remap_moved_region_args
func @remap_moved_region_args() { func @remap_moved_region_args() {
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64): // CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
// CHECK-NEXT: "work"{{.*}} : (f64, f64) // CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
// CHECK-NEXT: "work"{{.*}} : (f64, f64, f32)
"test.region"() ({ "test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64): ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"work"(%i0, %i1) : (i64, i64) -> () "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> () }) : () -> ()
return return
} }
@ -58,8 +72,8 @@ func @remap_drop_region() {
// CHECK-NEXT: return // CHECK-NEXT: return
// CHECK-NEXT: } // CHECK-NEXT: }
"test.drop_op"() ({ "test.drop_op"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64): ^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"work"(%i0, %i1) : (i64, i64) -> () "work"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> () }) : () -> ()
return return
} }